Compare commits

...

9 Commits

Author SHA1 Message Date
David Mak 2c9b1f5330 meta: Update README to include info regarding pre-commit hooks 2024-06-12 16:06:41 +08:00
David Mak bd13630f9a meta: Add pre-commit configuration 2024-06-12 16:06:41 +08:00
David Mak 467ce051ec flake: Add pre-commit to dev environment 2024-06-12 16:06:41 +08:00
David Mak f78a0ca8ee meta: Restrict number of allowed lints 2024-06-12 16:06:41 +08:00
David Mak d151ed48a7 meta: Set clippy lints in {main,lib}.rs
So that this does not have to be manually passed to the `cargo clippy`
command-line every single time. Also allows incrementally addressing
these lints by removing and fixing them one-by-one.
2024-06-12 16:06:41 +08:00
David Mak ccbd4bfe55 meta: Add RustRover configuration files 2024-06-12 16:06:41 +08:00
lyken c4420e6ab9 core: refactor `get_builtins()` 2024-06-12 15:09:20 +08:00
lyken fd36f78005 core: refactor `PrimitiveDefinitionId` into enum `PrimDef` 2024-06-12 15:01:01 +08:00
lyken 8168692cc3 apply cargo fmt 2024-06-12 14:45:03 +08:00
76 changed files with 7011 additions and 7323 deletions

View File

@ -1 +1 @@
doc-valid-idents = ["NumPy", ".."] doc-valid-idents = ["CPython", "NumPy", ".."]

6
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@ -0,0 +1,5 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>

View File

@ -0,0 +1,8 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="ClangTidy" enabled="true" level="WARNING" enabled_by_default="true">
<option name="clangTidyChecks" value="-*,cert-dcl21-cpp,cert-dcl58-cpp,cert-err34-c,cert-err52-cpp,cert-err60-cpp,cert-flp30-c,cert-msc50-cpp,cert-msc51-cpp,cert-str34-c,google-default-arguments,google-explicit-constructor,google-runtime-operator,hicpp-exception-baseclass,hicpp-multiway-paths-covered,misc-misplaced-const,misc-new-delete-overloads,misc-no-recursion,misc-non-copyable-objects,misc-throw-by-value-catch-by-reference,misc-unconventional-assign-operator,misc-uniqueptr-reset-release,mpi-buffer-deref,mpi-type-mismatch,openmp-use-default-none,portability-simd-intrinsics,bugprone-*,cppcoreguidelines-*,modernize-*,performance-*" />
</inspection_tool>
</profile>
</component>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/nac3.iml" filepath="$PROJECT_DIR$/.idea/nac3.iml" />
</modules>
</component>
</project>

18
.idea/nac3.iml Normal file
View File

@ -0,0 +1,18 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="CPP_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/nac3artiq/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/nac3ast/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/nac3core/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/nac3ld/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/nac3parser/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/nac3standalone/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/runkernel/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/nac3standalone/demo/lib/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -0,0 +1,19 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="Build All (Debug)" type="CargoCommandRunConfiguration" factoryName="Cargo Command">
<option name="command" value="build" />
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
<envs />
<option name="emulateTerminal" value="false" />
<option name="channel" value="DEFAULT" />
<option name="requiredFeatures" value="true" />
<option name="allFeatures" value="false" />
<option name="withSudo" value="false" />
<option name="buildTarget" value="REMOTE" />
<option name="backtrace" value="SHORT" />
<option name="isRedirectInput" value="false" />
<option name="redirectInputPath" value="" />
<method v="2">
<option name="CARGO.BUILD_TASK_PROVIDER" enabled="true" />
</method>
</configuration>
</component>

View File

@ -0,0 +1,19 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="Build All (Release)" type="CargoCommandRunConfiguration" factoryName="Cargo Command">
<option name="command" value="build --release" />
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
<option name="emulateTerminal" value="false" />
<option name="channel" value="DEFAULT" />
<option name="requiredFeatures" value="true" />
<option name="allFeatures" value="false" />
<option name="withSudo" value="false" />
<option name="buildTarget" value="REMOTE" />
<option name="backtrace" value="SHORT" />
<envs />
<option name="isRedirectInput" value="false" />
<option name="redirectInputPath" value="" />
<method v="2">
<option name="CARGO.BUILD_TASK_PROVIDER" enabled="true" />
</method>
</configuration>
</component>

View File

@ -0,0 +1,19 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="Build Standalone (Debug)" type="CargoCommandRunConfiguration" factoryName="Cargo Command">
<option name="command" value="build --bin nac3standalone" />
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
<envs />
<option name="emulateTerminal" value="false" />
<option name="channel" value="DEFAULT" />
<option name="requiredFeatures" value="true" />
<option name="allFeatures" value="false" />
<option name="withSudo" value="false" />
<option name="buildTarget" value="REMOTE" />
<option name="backtrace" value="SHORT" />
<option name="isRedirectInput" value="false" />
<option name="redirectInputPath" value="" />
<method v="2">
<option name="CARGO.BUILD_TASK_PROVIDER" enabled="true" />
</method>
</configuration>
</component>

View File

@ -0,0 +1,19 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="Build Standalone (Release)" type="CargoCommandRunConfiguration" factoryName="Cargo Command">
<option name="command" value="build --release --bin nac3standalone" />
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
<envs />
<option name="emulateTerminal" value="false" />
<option name="channel" value="DEFAULT" />
<option name="requiredFeatures" value="true" />
<option name="allFeatures" value="false" />
<option name="withSudo" value="false" />
<option name="buildTarget" value="REMOTE" />
<option name="backtrace" value="SHORT" />
<option name="isRedirectInput" value="false" />
<option name="redirectInputPath" value="" />
<method v="2">
<option name="CARGO.BUILD_TASK_PROVIDER" enabled="true" />
</method>
</configuration>
</component>

View File

@ -0,0 +1,19 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="Clean" type="CargoCommandRunConfiguration" factoryName="Cargo Command">
<option name="command" value="clean " />
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
<option name="emulateTerminal" value="false" />
<option name="channel" value="DEFAULT" />
<option name="requiredFeatures" value="true" />
<option name="allFeatures" value="false" />
<option name="withSudo" value="false" />
<option name="buildTarget" value="REMOTE" />
<option name="backtrace" value="SHORT" />
<envs />
<option name="isRedirectInput" value="false" />
<option name="redirectInputPath" value="" />
<method v="2">
<option name="CARGO.BUILD_TASK_PROVIDER" enabled="true" />
</method>
</configuration>
</component>

View File

@ -0,0 +1,17 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="Nix Build (nac3artiq)" type="ShConfigurationType">
<option name="SCRIPT_TEXT" value="nix build -L .#packages.x86_64-linux.nac3artiq" />
<option name="INDEPENDENT_SCRIPT_PATH" value="true" />
<option name="SCRIPT_PATH" value="" />
<option name="SCRIPT_OPTIONS" value="" />
<option name="INDEPENDENT_SCRIPT_WORKING_DIRECTORY" value="true" />
<option name="SCRIPT_WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="INDEPENDENT_INTERPRETER_PATH" value="true" />
<option name="INTERPRETER_PATH" value="/bin/sh" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="EXECUTE_IN_TERMINAL" value="false" />
<option name="EXECUTE_SCRIPT_FILE" value="false" />
<envs />
<method v="2" />
</configuration>
</component>

View File

@ -0,0 +1,19 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="Test" type="CargoCommandRunConfiguration" factoryName="Cargo Command">
<option name="command" value="test" />
<option name="workingDirectory" value="file://$PROJECT_DIR$" />
<option name="emulateTerminal" value="false" />
<option name="channel" value="DEFAULT" />
<option name="requiredFeatures" value="true" />
<option name="allFeatures" value="false" />
<option name="withSudo" value="false" />
<option name="buildTarget" value="REMOTE" />
<option name="backtrace" value="SHORT" />
<envs />
<option name="isRedirectInput" value="false" />
<option name="redirectInputPath" value="" />
<method v="2">
<option name="CARGO.BUILD_TASK_PROVIDER" enabled="true" />
</method>
</configuration>
</component>

View File

@ -0,0 +1,21 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="check_demos.sh (Debug, -O0)" type="ShConfigurationType">
<option name="SCRIPT_TEXT" value="" />
<option name="INDEPENDENT_SCRIPT_PATH" value="true" />
<option name="SCRIPT_PATH" value="$PROJECT_DIR$/nac3standalone/demo/check_demos.sh" />
<option name="SCRIPT_OPTIONS" value="--debug -- -O0" />
<option name="INDEPENDENT_SCRIPT_WORKING_DIRECTORY" value="true" />
<option name="SCRIPT_WORKING_DIRECTORY" value="$PROJECT_DIR$/nac3standalone/demo" />
<option name="INDEPENDENT_INTERPRETER_PATH" value="true" />
<option name="INTERPRETER_PATH" value="/bin/bash" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="EXECUTE_IN_TERMINAL" value="false" />
<option name="EXECUTE_SCRIPT_FILE" value="true" />
<envs>
<env name="RUST_BACKTRACE" value="1" />
</envs>
<method v="2">
<option name="RunConfigurationTask" enabled="true" run_configuration_name="Build Standalone (Debug)" run_configuration_type="CargoCommandRunConfiguration" />
</method>
</configuration>
</component>

View File

@ -0,0 +1,19 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="check_demos.sh (Debug, -O2)" type="ShConfigurationType">
<option name="SCRIPT_TEXT" value="" />
<option name="INDEPENDENT_SCRIPT_PATH" value="true" />
<option name="SCRIPT_PATH" value="$PROJECT_DIR$/nac3standalone/demo/check_demos.sh" />
<option name="SCRIPT_OPTIONS" value="--debug" />
<option name="INDEPENDENT_SCRIPT_WORKING_DIRECTORY" value="true" />
<option name="SCRIPT_WORKING_DIRECTORY" value="$PROJECT_DIR$/nac3standalone/demo" />
<option name="INDEPENDENT_INTERPRETER_PATH" value="true" />
<option name="INTERPRETER_PATH" value="/bin/bash" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="EXECUTE_IN_TERMINAL" value="false" />
<option name="EXECUTE_SCRIPT_FILE" value="true" />
<envs />
<method v="2">
<option name="RunConfigurationTask" enabled="true" run_configuration_name="Build Standalone (Debug)" run_configuration_type="CargoCommandRunConfiguration" />
</method>
</configuration>
</component>

View File

@ -0,0 +1,19 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="check_demos.sh (Release, -O0)" type="ShConfigurationType">
<option name="SCRIPT_TEXT" value="" />
<option name="INDEPENDENT_SCRIPT_PATH" value="true" />
<option name="SCRIPT_PATH" value="$PROJECT_DIR$/nac3standalone/demo/check_demos.sh" />
<option name="SCRIPT_OPTIONS" value="-- -O0" />
<option name="INDEPENDENT_SCRIPT_WORKING_DIRECTORY" value="true" />
<option name="SCRIPT_WORKING_DIRECTORY" value="$PROJECT_DIR$/nac3standalone/demo" />
<option name="INDEPENDENT_INTERPRETER_PATH" value="true" />
<option name="INTERPRETER_PATH" value="/bin/bash" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="EXECUTE_IN_TERMINAL" value="false" />
<option name="EXECUTE_SCRIPT_FILE" value="true" />
<envs />
<method v="2">
<option name="RunConfigurationTask" enabled="true" run_configuration_name="Build Standalone (Release)" run_configuration_type="CargoCommandRunConfiguration" />
</method>
</configuration>
</component>

View File

@ -0,0 +1,19 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="check_demos.sh (Release, -O2)" type="ShConfigurationType">
<option name="SCRIPT_TEXT" value="" />
<option name="INDEPENDENT_SCRIPT_PATH" value="true" />
<option name="SCRIPT_PATH" value="$PROJECT_DIR$/nac3standalone/demo/check_demos.sh" />
<option name="SCRIPT_OPTIONS" value="" />
<option name="INDEPENDENT_SCRIPT_WORKING_DIRECTORY" value="true" />
<option name="SCRIPT_WORKING_DIRECTORY" value="$PROJECT_DIR$/nac3standalone/demo" />
<option name="INDEPENDENT_INTERPRETER_PATH" value="true" />
<option name="INTERPRETER_PATH" value="/bin/bash" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="EXECUTE_IN_TERMINAL" value="false" />
<option name="EXECUTE_SCRIPT_FILE" value="true" />
<envs />
<method v="2">
<option name="RunConfigurationTask" enabled="true" run_configuration_name="Build Standalone (Release)" run_configuration_type="CargoCommandRunConfiguration" />
</method>
</configuration>
</component>

13
.idea/vcs.xml Normal file
View File

@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="CommitMessageInspectionProfile">
<profile version="1.0">
<inspection_tool class="BodyLimit" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="SubjectBodySeparation" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="SubjectLimit" enabled="true" level="WARNING" enabled_by_default="true" />
</profile>
</component>
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

24
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,24 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
default_stages: [commit]
repos:
- repo: local
hooks:
- id: nac3-cargo-fmt
name: nac3 cargo format
entry: cargo
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo fmt on the codebase.
args: [fmt]
- id: nac3-cargo-clippy
name: nac3 cargo clippy
entry: cargo
language: system
types: [file, rust]
pass_filenames: false
description: Runs cargo clippy on the codebase.
args: [clippy]

21
Cargo.lock generated
View File

@ -625,6 +625,8 @@ dependencies = [
"parking_lot", "parking_lot",
"rayon", "rayon",
"regex", "regex",
"strum",
"strum_macros",
"test-case", "test-case",
] ]
@ -1116,6 +1118,25 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29"
[[package]]
name = "strum_macros"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.66",
]
[[package]] [[package]]
name = "syn" name = "syn"
version = "1.0.109" version = "1.0.109"

View File

@ -51,3 +51,12 @@ Use ``nix develop`` in this repository to enter a development shell.
If you are using a different shell than bash you can use e.g. ``nix develop --command fish``. If you are using a different shell than bash you can use e.g. ``nix develop --command fish``.
Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``. Build NAC3 with ``cargo build --release``. See the demonstrations in ``nac3artiq`` and ``nac3standalone``.
### Pre-Commit Hooks
You are strongly recommended to use the provided pre-commit hooks to automatically reformat files and check for non-optimal Rust practices using Clippy. Run `pre-commit install` to install the hook and `pre-commit` will automatically run `cargo fmt` and `cargo clippy` for you.
Several things to note:
- If `cargo fmt` or `cargo clippy` returns an error, the pre-commit hook will fail. You should fix all errors before trying to commit again.
- If `cargo fmt` reformats some files, the pre-commit hook will also fail. You should review the changes and, if satisfied, try to commit again.

View File

@ -159,6 +159,7 @@
# development tools # development tools
cargo-insta cargo-insta
clippy clippy
pre-commit
rustfmt rustfmt
]; ];
}; };

View File

@ -6,21 +6,20 @@ use nac3core::{
CodeGenContext, CodeGenerator, CodeGenContext, CodeGenerator,
}, },
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{DefinitionId, GenCall, helper::PRIMITIVE_DEF_IDS}, toplevel::{helper::PrimDef, DefinitionId, GenCall},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap} typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap},
}; };
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef}; use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{ use inkwell::{
context::Context, context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
module::Linkage,
types::IntType,
values::BasicValueEnum,
AddressSpace,
}; };
use pyo3::{PyObject, PyResult, Python, types::{PyDict, PyList}}; use pyo3::{
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns}; use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
@ -46,7 +45,7 @@ enum ParallelMode {
/// ///
/// Each function call within the `with` block (except those within a nested `sequential` block) /// Each function call within the `with` block (except those within a nested `sequential` block)
/// are treated to be executed in parallel. /// are treated to be executed in parallel.
Deep Deep,
} }
pub struct ArtiqCodeGenerator<'a> { pub struct ArtiqCodeGenerator<'a> {
@ -96,14 +95,13 @@ impl<'a> ArtiqCodeGenerator<'a> {
/// ///
/// Direct-`parallel` block context refers to when the generator is generating statements whose /// Direct-`parallel` block context refers to when the generator is generating statements whose
/// closest parent `with` statement is a `with parallel` block. /// closest parent `with` statement is a `with parallel` block.
fn timeline_reset_start( fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> {
&mut self,
ctx: &mut CodeGenContext<'_, '_>
) -> Result<(), String> {
if let Some(start) = self.start.clone() { if let Some(start) = self.start.clone() {
let start_val = self.gen_expr(ctx, &start)? let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum(
.unwrap() ctx,
.to_basic_value_enum(ctx, self, start.custom.unwrap())?; self,
start.custom.unwrap(),
)?;
self.timeline.emit_at_mu(ctx, start_val); self.timeline.emit_at_mu(ctx, start_val);
} }
@ -129,20 +127,20 @@ impl<'a> ArtiqCodeGenerator<'a> {
store_name: Option<&str>, store_name: Option<&str>,
) -> Result<(), String> { ) -> Result<(), String> {
if let Some(end) = end { if let Some(end) = end {
let old_end = self.gen_expr(ctx, &end)? let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum(
.unwrap()
.to_basic_value_enum(ctx, self, end.custom.unwrap())?;
let now = self.timeline.emit_now_mu(ctx);
let max = call_int_smax(
ctx,
old_end.into_int_value(),
now.into_int_value(),
Some("smax")
);
let end_store = self.gen_store_target(
ctx, ctx,
&end, self,
store_name.map(|name| format!("{name}.addr")).as_deref())? end.custom.unwrap(),
)?;
let now = self.timeline.emit_now_mu(ctx);
let max =
call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax"));
let end_store = self
.gen_store_target(
ctx,
&end,
store_name.map(|name| format!("{name}.addr")).as_deref(),
)?
.unwrap(); .unwrap();
ctx.builder.build_store(end_store, max).unwrap(); ctx.builder.build_store(end_store, max).unwrap();
} }
@ -164,11 +162,14 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
} }
} }
fn gen_block<'ctx, 'a, 'c, I: Iterator<Item=&'c Stmt<Option<Type>>>>( fn gen_block<'ctx, 'a, 'c, I: Iterator<Item = &'c Stmt<Option<Type>>>>(
&mut self, &mut self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
stmts: I stmts: I,
) -> Result<(), String> where Self: Sized { ) -> Result<(), String>
where
Self: Sized,
{
// Legacy parallel emits timeline end-update/timeline-reset after each top-level statement // Legacy parallel emits timeline end-update/timeline-reset after each top-level statement
// in the parallel block // in the parallel block
if self.parallel_mode == ParallelMode::Legacy { if self.parallel_mode == ParallelMode::Legacy {
@ -212,9 +213,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::With { items, body, .. } = &stmt.node else { let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() };
unreachable!()
};
if items.len() == 1 && items[0].optional_vars.is_none() { if items.len() == 1 && items[0].optional_vars.is_none() {
let item = &items[0]; let item = &items[0];
@ -239,9 +238,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let old_parallel_mode = self.parallel_mode; let old_parallel_mode = self.parallel_mode;
let now = if let Some(old_start) = &old_start { let now = if let Some(old_start) = &old_start {
self.gen_expr(ctx, old_start)? self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum(
.unwrap() ctx,
.to_basic_value_enum(ctx, self, old_start.custom.unwrap())? self,
old_start.custom.unwrap(),
)?
} else { } else {
self.timeline.emit_now_mu(ctx) self.timeline.emit_now_mu(ctx)
}; };
@ -259,7 +260,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let start_expr = Located { let start_expr = Located {
// location does not matter at this point // location does not matter at this point
location: stmt.location, location: stmt.location,
node: ExprKind::Name { id: start, ctx: name_ctx.clone() }, node: ExprKind::Name { id: start, ctx: *name_ctx },
custom: Some(ctx.primitives.int64), custom: Some(ctx.primitives.int64),
}; };
let start = self let start = self
@ -274,12 +275,10 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let end_expr = Located { let end_expr = Located {
// location does not matter at this point // location does not matter at this point
location: stmt.location, location: stmt.location,
node: ExprKind::Name { id: end, ctx: name_ctx.clone() }, node: ExprKind::Name { id: end, ctx: *name_ctx },
custom: Some(ctx.primitives.int64), custom: Some(ctx.primitives.int64),
}; };
let end = self let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap();
.gen_store_target(ctx, &end_expr, Some("end.addr"))?
.unwrap();
ctx.builder.build_store(end, now).unwrap(); ctx.builder.build_store(end, now).unwrap();
self.end = Some(end_expr); self.end = Some(end_expr);
self.name_counter += 1; self.name_counter += 1;
@ -309,10 +308,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
// set duration // set duration
let end_expr = self.end.take().unwrap(); let end_expr = self.end.take().unwrap();
let end_val = self let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum(
.gen_expr(ctx, &end_expr)? ctx,
.unwrap() self,
.to_basic_value_enum(ctx, self, end_expr.custom.unwrap())?; end_expr.custom.unwrap(),
)?;
// inside a sequential block // inside a sequential block
if old_start.is_none() { if old_start.is_none() {
@ -416,7 +416,7 @@ fn rpc_codegen_callback_fn<'ctx>(
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false); let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1.0 as u64, false); let service_id = int32.const_int(fun.1 .0 as u64, false);
// -- setup rpc tags // -- setup rpc tags
let mut tag = Vec::new(); let mut tag = Vec::new();
if obj.is_some() { if obj.is_some() {
@ -442,7 +442,7 @@ fn rpc_codegen_callback_fn<'ctx>(
format!("tagptr{}", fun.1 .0).as_str(), format!("tagptr{}", fun.1 .0).as_str(),
); );
tag_arr_ptr.set_initializer(&int8.const_array( tag_arr_ptr.set_initializer(&int8.const_array(
&tag.iter().map(|v| int8.const_int(*v as u64, false)).collect::<Vec<_>>(), &tag.iter().map(|v| int8.const_int(u64::from(*v), false)).collect::<Vec<_>>(),
)); ));
tag_arr_ptr.set_linkage(Linkage::Private); tag_arr_ptr.set_linkage(Linkage::Private);
let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash); let tag_ptr = ctx.module.add_global(tag_ptr_type, None, &hash);
@ -461,7 +461,8 @@ fn rpc_codegen_callback_fn<'ctx>(
let arg_length = args.len() + usize::from(obj.is_some()); let arg_length = args.len() + usize::from(obj.is_some());
let stackptr = call_stacksave(ctx, Some("rpc.stack")); let stackptr = call_stacksave(ctx, Some("rpc.stack"));
let args_ptr = ctx.builder let args_ptr = ctx
.builder
.build_array_alloca( .build_array_alloca(
ptr_type, ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false), ctx.ctx.i32_type().const_int(arg_length as u64, false),
@ -477,10 +478,8 @@ fn rpc_codegen_callback_fn<'ctx>(
} }
// default value handling // default value handling
for k in keys { for k in keys {
mapping.insert( mapping
k.name, .insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()
);
} }
// reorder the parameters // reorder the parameters
let mut real_params = fun let mut real_params = fun
@ -499,7 +498,8 @@ fn rpc_codegen_callback_fn<'ctx>(
} }
for (i, arg) in real_params.iter().enumerate() { for (i, arg) in real_params.iter().enumerate() {
let arg_slot = generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap(); let arg_slot =
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
ctx.builder.build_store(arg_slot, *arg).unwrap(); ctx.builder.build_store(arg_slot, *arg).unwrap();
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap(); let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
let arg_ptr = unsafe { let arg_ptr = unsafe {
@ -508,7 +508,8 @@ fn rpc_codegen_callback_fn<'ctx>(
&[int32.const_int(i as u64, false)], &[int32.const_int(i as u64, false)],
&format!("rpc.arg{i}"), &format!("rpc.arg{i}"),
) )
}.unwrap(); }
.unwrap();
ctx.builder.build_store(arg_ptr, arg_slot).unwrap(); ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
} }
@ -528,11 +529,7 @@ fn rpc_codegen_callback_fn<'ctx>(
) )
}); });
ctx.builder ctx.builder
.build_call( .build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
rpc_send,
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
"rpc.send",
)
.unwrap(); .unwrap();
// reclaim stack space used by arguments // reclaim stack space used by arguments
@ -575,13 +572,9 @@ fn rpc_codegen_callback_fn<'ctx>(
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next") .build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap() .unwrap()
.into_int_value(); .into_int_value();
let is_done = ctx.builder let is_done = ctx
.build_int_compare( .builder
inkwell::IntPredicate::EQ, .build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
int32.const_zero(),
alloc_size,
"rpc.done",
)
.unwrap(); .unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap(); ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
@ -617,9 +610,15 @@ pub fn attributes_writeback(
let mut scratch_buffer = Vec::new(); let mut scratch_buffer = Vec::new();
for val in (*globals).values() { for val in (*globals).values() {
let val = val.as_ref(py); let val = val.as_ref(py);
let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?; let ty = inner_resolver.get_obj_type(
py,
val,
&mut ctx.unifier,
&top_levels,
&ctx.primitives,
)?;
if let Err(ty) = ty { if let Err(ty) = ty {
return Ok(Err(ty)) return Ok(Err(ty));
} }
let ty = ty.unwrap(); let ty = ty.unwrap();
match &*ctx.unifier.get_ty(ty) { match &*ctx.unifier.get_ty(ty) {
@ -632,14 +631,19 @@ pub fn attributes_writeback(
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
for (name, (field_ty, is_mutable)) in fields { for (name, (field_ty, is_mutable)) in fields {
if !is_mutable { if !is_mutable {
continue continue;
} }
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() { if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string()); attributes.push(name.to_string());
let index = ctx.get_attr_index(ty, *name); let index = ctx.get_attr_index(ty, *name);
values.push((*field_ty, ctx.build_gep_and_load( values.push((
obj.into_pointer_value(), *field_ty,
&[zero, int32.const_int(index as u64, false)], None))); ctx.build_gep_and_load(
obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
None,
),
));
} }
} }
if !attributes.is_empty() { if !attributes.is_empty() {
@ -648,33 +652,44 @@ pub fn attributes_writeback(
pydict.set_item("fields", attributes)?; pydict.set_item("fields", attributes)?;
host_attributes.append(pydict)?; host_attributes.append(pydict)?;
} }
}, }
TypeEnum::TList { ty: elem_ty } => { TypeEnum::TList { ty: elem_ty } => {
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() { if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py); let pydict = PyDict::new(py);
pydict.set_item("obj", val)?; pydict.set_item("obj", val)?;
host_attributes.append(pydict)?; host_attributes.append(pydict)?;
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); values.push((
ty,
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
));
} }
}, }
_ => {} _ => {}
} }
} }
let fun = FunSignature { let fun = FunSignature {
args: values.iter().enumerate().map(|(i, (ty, _))| FuncArg { args: values
name: i.to_string().into(), .iter()
ty: *ty, .enumerate()
default_value: None .map(|(i, (ty, _))| FuncArg {
}).collect(), name: i.to_string().into(),
ty: *ty,
default_value: None,
})
.collect(),
ret: ctx.primitives.none, ret: ctx.primitives.none,
vars: VarMap::default() vars: VarMap::default(),
}; };
let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect(); let args: Vec<_> =
if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, PRIMITIVE_DEF_IDS.int32), args, generator) { values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
if let Err(e) =
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator)
{
return Ok(Err(e)); return Ok(Err(e));
} }
Ok(Ok(())) Ok(Ok(()))
}).unwrap()?; })
.unwrap()?;
Ok(()) Ok(())
} }

View File

@ -1,3 +1,21 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
unsafe_op_in_unsafe_fn,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fs; use std::fs;
use std::io::Write; use std::io::Write;
@ -14,16 +32,16 @@ use inkwell::{
OptimizationLevel, OptimizationLevel,
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3core::codegen::{CodeGenLLVMOptions, CodeGenTargetMachineOptions, gen_func_impl}; use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
use nac3core::toplevel::builtins::get_exn_constructor; use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{TypeEnum, Unifier, VarMap}; use nac3core::typecheck::typedef::{TypeEnum, Unifier, VarMap};
use nac3parser::{ use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef}, ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program, parser::parse_program,
}; };
use pyo3::create_exception;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet}; use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use pyo3::create_exception;
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
@ -46,7 +64,7 @@ use tempfile::{self, TempDir};
use crate::codegen::attributes_writeback; use crate::codegen::attributes_writeback;
use crate::{ use crate::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator}, codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore}, symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
}; };
mod codegen; mod codegen;
@ -138,9 +156,7 @@ impl Nac3 {
for mut stmt in parser_result { for mut stmt in parser_result {
let include = match stmt.node { let include = match stmt.node {
StmtKind::ClassDef { StmtKind::ClassDef { ref decorator_list, ref mut body, ref mut bases, .. } => {
ref decorator_list, ref mut body, ref mut bases, ..
} => {
let nac3_class = decorator_list.iter().any(|decorator| { let nac3_class = decorator_list.iter().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node { if let ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "nac3" id.to_string() == "nac3"
@ -160,7 +176,8 @@ impl Nac3 {
if *id == "Exception".into() { if *id == "Exception".into() {
Ok(true) Ok(true)
} else { } else {
let base_obj = module.getattr(py, id.to_string().as_str())?; let base_obj =
module.getattr(py, id.to_string().as_str())?;
let base_id = id_fn.call1((base_obj,))?.extract()?; let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id)) Ok(registered_class_ids.contains(&base_id))
} }
@ -341,8 +358,9 @@ impl Nac3 {
let class_obj; let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node { if let StmtKind::ClassDef { name, .. } = &stmt.node {
let class = py_module.getattr(name.to_string().as_str()).unwrap(); let class = py_module.getattr(name.to_string().as_str()).unwrap();
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap() && if issubclass.call1((class, exn_class)).unwrap().extract().unwrap()
class.getattr("artiq_builtin").is_err() { && class.getattr("artiq_builtin").is_err()
{
class_obj = Some(class); class_obj = Some(class);
} else { } else {
class_obj = None; class_obj = None;
@ -388,12 +406,12 @@ impl Nac3 {
let (name, def_id, ty) = composer let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path, false) .register_top_level(stmt.clone(), Some(resolver.clone()), path, false)
.map_err(|e| { .map_err(|e| {
CompileError::new_err(format!( CompileError::new_err(format!("compilation failed\n----------\n{e}"))
"compilation failed\n----------\n{e}"
))
})?; })?;
if let Some(class_obj) = class_obj { if let Some(class_obj) = class_obj {
self.exception_ids.write().insert(def_id.0, store_obj.call1(py, (class_obj, ))?.extract(py)?); self.exception_ids
.write()
.insert(def_id.0, store_obj.call1(py, (class_obj,))?.extract(py)?);
} }
match &stmt.node { match &stmt.node {
@ -470,7 +488,8 @@ impl Nac3 {
exception_ids: self.exception_ids.clone(), exception_ids: self.exception_ids.clone(),
deferred_eval_store: self.deferred_eval_store.clone(), deferred_eval_store: self.deferred_eval_store.clone(),
}); });
let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; let resolver =
Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let (_, def_id, _) = composer let (_, def_id, _) = composer
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false) .register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
.unwrap(); .unwrap();
@ -479,8 +498,12 @@ impl Nac3 {
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() }; FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = let signature = store.from_signature(
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); &mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature = store.add_cty(signature); let signature = store.add_cty(signature);
if let Err(e) = composer.start_analysis(true) { if let Err(e) = composer.start_analysis(true) {
@ -499,13 +522,11 @@ impl Nac3 {
msg.unwrap_or(e.iter().sorted().join("\n----------\n")) msg.unwrap_or(e.iter().sorted().join("\n----------\n"))
))) )))
} else { } else {
Err(CompileError::new_err( Err(CompileError::new_err(format!(
format!( "compilation failed\n----------\n{}",
"compilation failed\n----------\n{}", e.iter().sorted().join("\n----------\n"),
e.iter().sorted().join("\n----------\n"), )))
), };
))
}
} }
let top_level = Arc::new(composer.make_top_level_context()); let top_level = Arc::new(composer.make_top_level_context());
@ -533,7 +554,9 @@ impl Nac3 {
py, py,
( (
id.0.into_py(py), id.0.into_py(py),
class_def.getattr(py, name.to_string().as_str()).unwrap(), class_def
.getattr(py, name.to_string().as_str())
.unwrap(),
), ),
) )
.unwrap(); .unwrap();
@ -548,7 +571,8 @@ impl Nac3 {
let defs = top_level.definitions.read(); let defs = top_level.definitions.read();
let mut definition = defs[def_id.0].write(); let mut definition = defs[def_id.0].write();
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } =
&mut *definition else { &mut *definition
else {
unreachable!() unreachable!()
}; };
@ -570,8 +594,12 @@ impl Nac3 {
let mut store = ConcreteTypeStore::new(); let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new(); let mut cache = HashMap::new();
let signature = let signature = store.from_signature(
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache); &mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature = store.add_cty(signature); let signature = store.add_cty(signature);
let attributes_writeback_task = CodeGenTask { let attributes_writeback_task = CodeGenTask {
subst: Vec::default(), subst: Vec::default(),
@ -604,23 +632,28 @@ impl Nac3 {
let membuffer = membuffers.clone(); let membuffer = membuffers.clone();
py.allow_threads(|| { py.allow_threads(|| {
let (registry, handles) = WorkerRegistry::create_workers( let (registry, handles) =
threads, WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
top_level.clone(),
&self.llvm_options,
&f
);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
let mut generator = ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns); let mut generator =
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create(); let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback"); let module = context.create_module("attributes_writeback");
let builder = context.create_builder(); let builder = context.create_builder();
let (_, module, _) = gen_func_impl(&context, &mut generator, &registry, builder, module, let (_, module, _) = gen_func_impl(
attributes_writeback_task, |generator, ctx| { &context,
&mut generator,
&registry,
builder,
module,
attributes_writeback_task,
|generator, ctx| {
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes) attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes)
}).unwrap(); },
)
.unwrap();
let buffer = module.write_bitcode_to_memory(); let buffer = module.write_bitcode_to_memory();
let buffer = buffer.as_slice().into(); let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer); membuffer.lock().push(buffer);
@ -636,11 +669,16 @@ impl Nac3 {
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main")) .create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
.unwrap(); .unwrap();
main.link_in_module(other) main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
.map_err(|err| CompileError::new_err(err.to_string()))?;
} }
let builder = context.create_builder(); let builder = context.create_builder();
let modinit_return = main.get_function("__modinit__").unwrap().get_last_basic_block().unwrap().get_terminator().unwrap(); let modinit_return = main
.get_function("__modinit__")
.unwrap()
.get_last_basic_block()
.unwrap()
.get_terminator()
.unwrap();
builder.position_before(&modinit_return); builder.position_before(&modinit_return);
builder builder
.build_call( .build_call(
@ -662,10 +700,7 @@ impl Nac3 {
} }
// Demote all global variables that will not be referenced in the kernel to private // Demote all global variables that will not be referenced in the kernel to private
let preserved_symbols: Vec<&'static [u8]> = vec![ let preserved_symbols: Vec<&'static [u8]> = vec![b"typeinfo", b"now"];
b"typeinfo",
b"now",
];
let mut global_option = main.get_first_global(); let mut global_option = main.get_first_global();
while let Some(global) = global_option { while let Some(global) = global_option {
if !preserved_symbols.contains(&(global.get_name().to_bytes())) { if !preserved_symbols.contains(&(global.get_name().to_bytes())) {
@ -674,7 +709,9 @@ impl Nac3 {
global_option = global.get_next_global(); global_option = global.get_next_global();
} }
let target_machine = self.llvm_options.target let target_machine = self
.llvm_options
.target
.create_target_machine(self.llvm_options.opt_level) .create_target_machine(self.llvm_options.opt_level)
.expect("couldn't create target machine"); .expect("couldn't create target machine");
@ -738,10 +775,7 @@ impl Nac3 {
} }
} }
fn link_with_lld( fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
elf_filename: String,
obj_filename: String,
) -> PyResult<()>{
let linker_args = vec![ let linker_args = vec![
"-shared".to_string(), "-shared".to_string(),
"--eh-frame-hdr".to_string(), "--eh-frame-hdr".to_string(),
@ -760,9 +794,7 @@ fn link_with_lld(
return Err(CompileError::new_err("failed to start linker")); return Err(CompileError::new_err("failed to start linker"));
} }
} else { } else {
return Err(CompileError::new_err( return Err(CompileError::new_err("linker returned non-zero status code"));
"linker returned non-zero status code",
));
} }
Ok(()) Ok(())
@ -772,7 +804,7 @@ fn add_exceptions(
composer: &mut TopLevelComposer, composer: &mut TopLevelComposer,
builtin_def: &mut HashMap<StrRef, DefinitionId>, builtin_def: &mut HashMap<StrRef, DefinitionId>,
builtin_ty: &mut HashMap<StrRef, Type>, builtin_ty: &mut HashMap<StrRef, Type>,
error_names: &[&str] error_names: &[&str],
) -> Vec<Type> { ) -> Vec<Type> {
let mut types = Vec::new(); let mut types = Vec::new();
// note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}" // note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}"
@ -785,7 +817,7 @@ fn add_exceptions(
// constructor id // constructor id
def_id + 1, def_id + 1,
&mut composer.unifier, &mut composer.unifier,
&composer.primitives_ty &composer.primitives_ty,
); );
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None)); composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None));
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None)); composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None));
@ -834,7 +866,8 @@ impl Nac3 {
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_at_mu(ctx, arg); time_fns.emit_at_mu(ctx, arg);
Ok(None) Ok(None)
}))), }))),
@ -852,7 +885,8 @@ impl Nac3 {
}, },
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| { Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty; let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap(); let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_delay_mu(ctx, arg); time_fns.emit_delay_mu(ctx, arg);
Ok(None) Ok(None)
}))), }))),
@ -867,8 +901,9 @@ impl Nac3 {
let types_mod = PyModule::import(py, "types").unwrap(); let types_mod = PyModule::import(py, "types").unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap(); let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap();
let get_attr_id = |obj: &PyModule, attr| id_fn.call1((obj.getattr(attr).unwrap(),)) let get_attr_id = |obj: &PyModule, attr| {
.unwrap().extract().unwrap(); id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap()
};
let primitive_ids = PrimitivePythonId { let primitive_ids = PrimitivePythonId {
virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()), virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()),
generic_alias: ( generic_alias: (
@ -877,7 +912,9 @@ impl Nac3 {
), ),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()), none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()),
typevar: get_attr_id(typing_mod, "TypeVar"), typevar: get_attr_id(typing_mod, "TypeVar"),
const_generic_marker: get_id(artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap()), const_generic_marker: get_id(
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(),
),
int: get_attr_id(builtins_mod, "int"), int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"), int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"), int64: get_attr_id(numpy_mod, "int64"),
@ -911,7 +948,7 @@ impl Nac3 {
llvm_options: CodeGenLLVMOptions { llvm_options: CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: Nac3::get_llvm_target_options(isa), target: Nac3::get_llvm_target_options(isa),
} },
}) })
} }
@ -952,7 +989,7 @@ impl Nac3 {
py: Python, py: Python,
) -> PyResult<()> { ) -> PyResult<()> {
let target_machine = self.get_llvm_target_machine(); let target_machine = self.get_llvm_target_machine();
if self.isa == Isa::Host { if self.isa == Isa::Host {
let link_fn = |module: &Module| { let link_fn = |module: &Module| {
let working_directory = self.working_directory.path().to_owned(); let working_directory = self.working_directory.path().to_owned();
@ -961,7 +998,7 @@ impl Nac3 {
.expect("couldn't write module to file"); .expect("couldn't write module to file");
link_with_lld( link_with_lld(
filename.to_string(), filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string() working_directory.join("module.o").to_string_lossy().to_string(),
)?; )?;
Ok(()) Ok(())
}; };
@ -997,7 +1034,7 @@ impl Nac3 {
py: Python, py: Python,
) -> PyResult<PyObject> { ) -> PyResult<PyObject> {
let target_machine = self.get_llvm_target_machine(); let target_machine = self.get_llvm_target_machine();
if self.isa == Isa::Host { if self.isa == Isa::Host {
let link_fn = |module: &Module| { let link_fn = |module: &Module| {
let working_directory = self.working_directory.path().to_owned(); let working_directory = self.working_directory.path().to_owned();
@ -1009,7 +1046,7 @@ impl Nac3 {
let filename = filename_path.to_str().unwrap(); let filename = filename_path.to_str().unwrap();
link_with_lld( link_with_lld(
filename.to_string(), filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string() working_directory.join("module.o").to_string_lossy().to_string(),
)?; )?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into()) Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())

View File

@ -3,10 +3,9 @@ use nac3core::{
codegen::{CodeGenContext, CodeGenerator}, codegen::{CodeGenContext, CodeGenerator},
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum}, symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{ toplevel::{
DefinitionId, helper::PrimDef,
helper::PRIMITIVE_DEF_IDS,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelDef, DefinitionId, TopLevelDef,
}, },
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
@ -22,9 +21,9 @@ use pyo3::{
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
sync::{ sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc, Arc,
atomic::{AtomicBool, Ordering::Relaxed} },
}
}; };
use crate::PrimitivePythonId; use crate::PrimitivePythonId;
@ -58,7 +57,7 @@ impl DeferredEvaluationStore {
} }
} }
/// A class field as stored in the [`InnerResolver`], represented by the ID and name of the /// A class field as stored in the [`InnerResolver`], represented by the ID and name of the
/// associated [`PythonValue`]. /// associated [`PythonValue`].
type ResolverField = (u64, StrRef); type ResolverField = (u64, StrRef);
/// A class field as stored in Python, represented by the `id()` and [`PyObject`] of the field. /// A class field as stored in Python, represented by the `id()` and [`PyObject`] of the field.
@ -114,27 +113,27 @@ impl StaticValue for PythonValue {
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator, _: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx> { ) -> BasicValueEnum<'ctx> {
ctx.module ctx.module.get_global(format!("{}_const", self.id).as_str()).map_or_else(
.get_global(format!("{}_const", self.id).as_str()) || {
.map_or_else( Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
|| Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> { let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?;
let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?; let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false);
let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false); let global = ctx.module.add_global(
let global = ctx.module.add_global( struct_type,
struct_type, None,
None, format!("{}_const", self.id).as_str(),
format!("{}_const", self.id).as_str(), );
); global.set_constant(true);
global.set_constant(true); global.set_initializer(&ctx.ctx.const_struct(
global.set_initializer(&ctx.ctx.const_struct( &[ctx.ctx.i32_type().const_int(u64::from(id), false).into()],
&[ctx.ctx.i32_type().const_int(id as u64, false).into()], false,
false, ));
)); Ok(global.as_pointer_value().into())
Ok(global.as_pointer_value().into()) })
}) .unwrap()
.unwrap(), },
|val| val.as_pointer_value().into(), |val| val.as_pointer_value().into(),
) )
} }
fn to_basic_value_enum<'ctx, 'a>( fn to_basic_value_enum<'ctx, 'a>(
@ -147,10 +146,14 @@ impl StaticValue for PythonValue {
return Ok(match val { return Ok(match val {
PrimitiveValue::I32(val) => ctx.ctx.i32_type().const_int(*val as u64, false).into(), PrimitiveValue::I32(val) => ctx.ctx.i32_type().const_int(*val as u64, false).into(),
PrimitiveValue::I64(val) => ctx.ctx.i64_type().const_int(*val as u64, false).into(), PrimitiveValue::I64(val) => ctx.ctx.i64_type().const_int(*val as u64, false).into(),
PrimitiveValue::U32(val) => ctx.ctx.i32_type().const_int(*val as u64, false).into(), PrimitiveValue::U32(val) => {
ctx.ctx.i32_type().const_int(u64::from(*val), false).into()
}
PrimitiveValue::U64(val) => ctx.ctx.i64_type().const_int(*val, false).into(), PrimitiveValue::U64(val) => ctx.ctx.i64_type().const_int(*val, false).into(),
PrimitiveValue::F64(val) => ctx.ctx.f64_type().const_float(*val).into(), PrimitiveValue::F64(val) => ctx.ctx.f64_type().const_float(*val).into(),
PrimitiveValue::Bool(val) => ctx.ctx.i8_type().const_int(*val as u64, false).into(), PrimitiveValue::Bool(val) => {
ctx.ctx.i8_type().const_int(u64::from(*val), false).into()
}
}); });
} }
if let Some(global) = ctx.module.get_global(&self.id.to_string()) { if let Some(global) = ctx.module.get_global(&self.id.to_string()) {
@ -161,7 +164,8 @@ impl StaticValue for PythonValue {
self.resolver self.resolver
.get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty) .get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty)
.map(Option::unwrap) .map(Option::unwrap)
}).map_err(|e| e.to_string()) })
.map_err(|e| e.to_string())
} }
fn get_field<'ctx>( fn get_field<'ctx>(
@ -186,7 +190,7 @@ impl StaticValue for PythonValue {
Ok(None) Ok(None)
} else { } else {
Ok(Some((id, obj))) Ok(Some((id, obj)))
} };
} }
let def_id = { *self.resolver.pyid_to_def.read().get(&ty_id).unwrap() }; let def_id = { *self.resolver.pyid_to_def.read().get(&ty_id).unwrap() };
let mut mutable = true; let mut mutable = true;
@ -264,9 +268,7 @@ impl InnerResolver {
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))?? .map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))??
{ {
Ok(t) => t, Ok(t) => t,
Err(e) => { Err(e) => return Ok(Err(format!("type error ({e}) at element #{i} of the list"))),
return Ok(Err(format!("type error ({e}) at element #{i} of the list")))
}
}; };
ty = match unifier.unify(ty, b) { ty = match unifier.unify(ty, b) {
Ok(()) => ty, Ok(()) => ty,
@ -377,7 +379,7 @@ impl InnerResolver {
let constr_id: u64 = self.helper.id_fn.call1(py, (constr,))?.extract(py)?; let constr_id: u64 = self.helper.id_fn.call1(py, (constr,))?.extract(py)?;
if constr_id == self.primitive_ids.const_generic_marker { if constr_id == self.primitive_ids.const_generic_marker {
is_const_generic = true; is_const_generic = true;
continue continue;
} }
if !is_const_generic && needs_defer { if !is_const_generic && needs_defer {
@ -406,11 +408,11 @@ impl InnerResolver {
} }
if !is_const_generic && needs_defer { if !is_const_generic && needs_defer {
self.deferred_eval_store.store.write() self.deferred_eval_store.store.write().push((
.push((result.clone(), result.clone(),
constraints.extract()?, constraints.extract()?,
pyty.getattr("__name__")?.extract::<String>()? pyty.getattr("__name__")?.extract::<String>()?,
)); ));
} }
(result, is_const_generic) (result, is_const_generic)
@ -418,7 +420,10 @@ impl InnerResolver {
let res = if is_const_generic { let res = if is_const_generic {
if constraint_types.len() != 1 { if constraint_types.len() != 1 {
return Ok(Err(format!("ConstGeneric expects 1 argument, got {}", constraint_types.len()))) return Ok(Err(format!(
"ConstGeneric expects 1 argument, got {}",
constraint_types.len()
)));
} }
unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).0 unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).0
@ -468,7 +473,7 @@ impl InnerResolver {
))); )));
} }
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
if args.len() != 2 { if args.len() != 2 {
return Ok(Err(format!( return Ok(Err(format!(
"type list needs exactly 2 type parameters, found {}", "type list needs exactly 2 type parameters, found {}",
@ -572,9 +577,7 @@ impl InnerResolver {
let str_fn = let str_fn =
pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap(); pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap();
let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap(); let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap();
Ok(Err(format!( Ok(Err(format!("{str_repr} is not registered with NAC3 (@nac3 decorator missing?)")))
"{str_repr} is not registered with NAC3 (@nac3 decorator missing?)"
)))
} }
} }
@ -589,31 +592,28 @@ impl InnerResolver {
let ty = self.helper.type_fn.call1(py, (obj,)).unwrap(); let ty = self.helper.type_fn.call1(py, (obj,)).unwrap();
let py_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let py_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
if let Some(ty) = self.pyid_to_type.read().get(&py_obj_id) { if let Some(ty) = self.pyid_to_type.read().get(&py_obj_id) {
return Ok(Ok(*ty)) return Ok(Ok(*ty));
} }
// check if constructor function exists in the methods list // check if constructor function exists in the methods list
let pyid_to_def = self.pyid_to_def.read(); let pyid_to_def = self.pyid_to_def.read();
let constructor_ty = pyid_to_def let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| {
.get(&py_obj_id) defs.iter().find_map(|def| {
.and_then(|def_id| { if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() {
defs if object_id == def_id
.iter() && constructor.is_some()
.find_map(|def| { && methods.iter().any(|(s, _, _)| s == &"__init__".into())
if let TopLevelDef::Class { {
object_id, methods, constructor, .. return *constructor;
} = &*def.read() {
if object_id == def_id && constructor.is_some() && methods.iter().any(|(s, _, _)| s == &"__init__".into()) {
return *constructor;
}
} }
None }
}) None
}); })
});
if let Some(ty) = constructor_ty { if let Some(ty) = constructor_ty {
self.pyid_to_type.write().insert(py_obj_id, ty); self.pyid_to_type.write().insert(py_obj_id, ty);
return Ok(Ok(ty)) return Ok(Ok(ty));
} }
let (extracted_ty, inst_check) = match self.get_pyty_obj_type( let (extracted_ty, inst_check) = match self.get_pyty_obj_type(
@ -664,7 +664,7 @@ impl InnerResolver {
} }
} }
} }
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
if len == 0 { if len == 0 {
@ -680,12 +680,8 @@ impl InnerResolver {
match actual_ty { match actual_ty {
Ok(t) => match unifier.unify(ty, t) { Ok(t) => match unifier.unify(ty, t) {
Ok(()) => { Ok(()) => {
let ndarray_ty = make_ndarray_ty( let ndarray_ty =
unifier, make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
primitives,
Some(ty),
Some(ndims),
);
Ok(Ok(ndarray_ty)) Ok(Ok(ndarray_ty))
} }
@ -726,7 +722,8 @@ impl InnerResolver {
let var_map = params let var_map = params
.iter() .iter()
.map(|(id_var, ty)| { .map(|(id_var, ty)| {
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) else { let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty)
else {
unreachable!() unreachable!()
}; };
@ -734,7 +731,7 @@ impl InnerResolver {
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
}) })
.collect::<VarMap>(); .collect::<VarMap>();
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())) return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()));
} }
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
@ -754,8 +751,8 @@ impl InnerResolver {
let var_map = params let var_map = params
.iter() .iter()
.map(|(id_var, ty)| { .map(|(id_var, ty)| {
let TypeEnum::TVar { id, range, name, loc, .. } = let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty)
&*unifier.get_ty(*ty) else { else {
unreachable!() unreachable!()
}; };
@ -767,25 +764,23 @@ impl InnerResolver {
// loop through non-function fields of the class to get the instantiated value // loop through non-function fields of the class to get the instantiated value
for field in fields { for field in fields {
let name: String = (*field.0).into(); let name: String = (*field.0).into();
if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1.0) { if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1 .0) {
continue; continue;
} }
let field_data = match obj.getattr(name.as_str()) { let field_data = match obj.getattr(name.as_str()) {
Ok(d) => d, Ok(d) => d,
Err(e) => return Ok(Err(format!("{e}"))), Err(e) => return Ok(Err(format!("{e}"))),
}; };
let ty = match self let ty =
.get_obj_type(py, field_data, unifier, defs, primitives)? match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
{ Ok(t) => t,
Ok(t) => t, Err(e) => {
Err(e) => { return Ok(Err(format!(
return Ok(Err(format!( "error when getting type of field `{name}` ({e})"
"error when getting type of field `{name}` ({e})" )))
))) }
} };
}; let field_ty = unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0);
let field_ty =
unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0);
if let Err(e) = unifier.unify(ty, field_ty) { if let Err(e) = unifier.unify(ty, field_ty) {
// field type mismatch // field type mismatch
return Ok(Err(format!( return Ok(Err(format!(
@ -800,14 +795,15 @@ impl InnerResolver {
return Ok(Err("object is not of concrete type".into())); return Ok(Err("object is not of concrete type".into()));
} }
} }
let extracted_ty = unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty); let extracted_ty =
unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty);
Ok(Ok(extracted_ty)) Ok(Ok(extracted_ty))
}; };
let result = instantiate_obj(); let result = instantiate_obj();
// update/remove the cache according to the result // update/remove the cache according to the result
match result { match result {
Ok(Ok(ty)) => self.pyid_to_type.write().insert(py_obj_id, ty), Ok(Ok(ty)) => self.pyid_to_type.write().insert(py_obj_id, ty),
_ => self.pyid_to_type.write().remove(&py_obj_id) _ => self.pyid_to_type.write().remove(&py_obj_id),
}; };
result result
} }
@ -816,32 +812,32 @@ impl InnerResolver {
if unifier.unioned(extracted_ty, primitives.int32) { if unifier.unioned(extracted_ty, primitives.int32) {
obj.extract::<i32>().map_or_else( obj.extract::<i32>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of int32"))), |_| Ok(Err(format!("{obj} is not in the range of int32"))),
|_| Ok(Ok(extracted_ty)) |_| Ok(Ok(extracted_ty)),
) )
} else if unifier.unioned(extracted_ty, primitives.int64) { } else if unifier.unioned(extracted_ty, primitives.int64) {
obj.extract::<i64>().map_or_else( obj.extract::<i64>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of int64"))), |_| Ok(Err(format!("{obj} is not in the range of int64"))),
|_| Ok(Ok(extracted_ty)) |_| Ok(Ok(extracted_ty)),
) )
} else if unifier.unioned(extracted_ty, primitives.uint32) { } else if unifier.unioned(extracted_ty, primitives.uint32) {
obj.extract::<u32>().map_or_else( obj.extract::<u32>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of uint32"))), |_| Ok(Err(format!("{obj} is not in the range of uint32"))),
|_| Ok(Ok(extracted_ty)) |_| Ok(Ok(extracted_ty)),
) )
} else if unifier.unioned(extracted_ty, primitives.uint64) { } else if unifier.unioned(extracted_ty, primitives.uint64) {
obj.extract::<u64>().map_or_else( obj.extract::<u64>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of uint64"))), |_| Ok(Err(format!("{obj} is not in the range of uint64"))),
|_| Ok(Ok(extracted_ty)) |_| Ok(Ok(extracted_ty)),
) )
} else if unifier.unioned(extracted_ty, primitives.bool) { } else if unifier.unioned(extracted_ty, primitives.bool) {
obj.extract::<bool>().map_or_else( obj.extract::<bool>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of bool"))), |_| Ok(Err(format!("{obj} is not in the range of bool"))),
|_| Ok(Ok(extracted_ty)) |_| Ok(Ok(extracted_ty)),
) )
} else if unifier.unioned(extracted_ty, primitives.float) { } else if unifier.unioned(extracted_ty, primitives.float) {
obj.extract::<f64>().map_or_else( obj.extract::<f64>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of float64"))), |_| Ok(Err(format!("{obj} is not in the range of float64"))),
|_| Ok(Ok(extracted_ty)) |_| Ok(Ok(extracted_ty)),
) )
} else { } else {
Ok(Ok(extracted_ty)) Ok(Ok(extracted_ty))
@ -872,7 +868,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.uint32 { } else if ty_id == self.primitive_ids.uint32 {
let val: u32 = obj.extract().unwrap(); let val: u32 = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::U32(val)); self.id_to_primitive.write().insert(id, PrimitiveValue::U32(val));
Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into())) Ok(Some(ctx.ctx.i32_type().const_int(u64::from(val), false).into()))
} else if ty_id == self.primitive_ids.uint64 { } else if ty_id == self.primitive_ids.uint64 {
let val: u64 = obj.extract().unwrap(); let val: u64 = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val)); self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val));
@ -880,7 +876,7 @@ impl InnerResolver {
} else if ty_id == self.primitive_ids.bool { } else if ty_id == self.primitive_ids.bool {
let val: bool = obj.extract().unwrap(); let val: bool = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val));
Ok(Some(ctx.ctx.i8_type().const_int(val as u64, false).into())) Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into()))
} else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 {
let val: f64 = obj.extract().unwrap(); let val: f64 = obj.extract().unwrap();
self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val)); self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val));
@ -893,8 +889,8 @@ impl InnerResolver {
} }
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
let elem_ty = let elem_ty = if let TypeEnum::TList { ty } =
if let TypeEnum::TList { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref() ctx.unifier.get_ty_immutable(expected_ty).as_ref()
{ {
*ty *ty
} else { } else {
@ -918,13 +914,11 @@ impl InnerResolver {
let arr: Result<Option<Vec<_>>, _> = (0..len) let arr: Result<Option<Vec<_>>, _> = (0..len)
.map(|i| { .map(|i| {
obj obj.get_item(i).and_then(|elem| {
.get_item(i) self.get_obj_value(py, elem, ctx, generator, elem_ty).map_err(|e| {
.and_then(|elem| self.get_obj_value(py, elem, ctx, generator, elem_ty) super::CompileError::new_err(format!("Error getting element {i}: {e}"))
.map_err( })
|e| super::CompileError::new_err( })
format!("Error getting element {i}: {e}"))
))
}) })
.collect(); .collect();
let arr = arr?.unwrap(); let arr = arr?.unwrap();
@ -956,7 +950,10 @@ impl InnerResolver {
arr_global.set_initializer(&arr); arr_global.set_initializer(&arr);
let val = arr_ty.const_named_struct(&[ let val = arr_ty.const_named_struct(&[
arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::default())).into(), arr_global
.as_pointer_value()
.const_cast(ty.ptr_type(AddressSpace::default()))
.into(),
size_t.const_int(len as u64, false).into(), size_t.const_int(len as u64, false).into(),
]); ]);
@ -968,25 +965,21 @@ impl InnerResolver {
todo!() todo!()
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
unreachable!()
};
let tup_tys = ty.iter(); let tup_tys = ty.iter();
let elements: &PyTuple = obj.downcast()?; let elements: &PyTuple = obj.downcast()?;
assert_eq!(elements.len(), tup_tys.len()); assert_eq!(elements.len(), tup_tys.len());
let val: Result<Option<Vec<_>>, _> = let val: Result<Option<Vec<_>>, _> = elements
elements .iter()
.iter() .enumerate()
.enumerate() .zip(tup_tys)
.zip(tup_tys) .map(|((i, elem), ty)| {
.map(|((i, elem), ty)| self self.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| {
.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| super::CompileError::new_err(format!("Error getting element {i}: {e}"))
super::CompileError::new_err( })
format!("Error getting element {i}: {e}") })
) .collect();
)
).collect();
let val = val?.unwrap(); let val = val?.unwrap();
let val = ctx.ctx.const_struct(&val, false); let val = ctx.ctx.const_struct(&val, false);
Ok(Some(val.into())) Ok(Some(val.into()))
@ -997,7 +990,7 @@ impl InnerResolver {
{ {
*params.iter().next().unwrap().1 *params.iter().next().unwrap().1
} }
_ => unreachable!("must be option type") _ => unreachable!("must be option type"),
}; };
if id == self.primitive_ids.none { if id == self.primitive_ids.none {
// for option type, just a null ptr // for option type, just a null ptr
@ -1009,7 +1002,13 @@ impl InnerResolver {
)) ))
} else { } else {
match self match self
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator, option_val_ty) .get_obj_value(
py,
obj.getattr("_nac3_option").unwrap(),
ctx,
generator,
option_val_ty,
)
.map_err(|e| { .map_err(|e| {
super::CompileError::new_err(format!( super::CompileError::new_err(format!(
"Error getting value of Option object: {e}" "Error getting value of Option object: {e}"
@ -1019,17 +1018,26 @@ impl InnerResolver {
let global_str = format!("{id}_option"); let global_str = format!("{id}_option");
{ {
if self.global_value_ids.read().contains_key(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| { let global =
ctx.module.add_global(v.get_type(), Some(AddressSpace::default()), &global_str) ctx.module.get_global(&global_str).unwrap_or_else(|| {
}); ctx.module.add_global(
v.get_type(),
Some(AddressSpace::default()),
&global_str,
)
});
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} }
self.global_value_ids.write().insert(id, obj.into()); self.global_value_ids.write().insert(id, obj.into());
} }
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::default()), &global_str); let global = ctx.module.add_global(
v.get_type(),
Some(AddressSpace::default()),
&global_str,
);
global.set_initializer(&v); global.set_initializer(&v);
Ok(Some(global.as_pointer_value().into())) Ok(Some(global.as_pointer_value().into()))
}, }
None => Ok(None), None => Ok(None),
} }
} }
@ -1066,8 +1074,16 @@ impl InnerResolver {
let values: Result<Option<Vec<_>>, _> = fields let values: Result<Option<Vec<_>>, _> = fields
.iter() .iter()
.map(|(name, ty, _)| { .map(|(name, ty, _)| {
self.get_obj_value(py, obj.getattr(name.to_string().as_str())?, ctx, generator, *ty) self.get_obj_value(
.map_err(|e| super::CompileError::new_err(format!("Error getting field {name}: {e}"))) py,
obj.getattr(name.to_string().as_str())?,
ctx,
generator,
*ty,
)
.map_err(|e| {
super::CompileError::new_err(format!("Error getting field {name}: {e}"))
})
}) })
.collect(); .collect();
let values = values?; let values = values?;
@ -1119,8 +1135,7 @@ impl InnerResolver {
if id == self.primitive_ids.none { if id == self.primitive_ids.none {
Ok(SymbolValue::OptionNone) Ok(SymbolValue::OptionNone)
} else { } else {
self self.get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())?
.get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())?
.map(|v| SymbolValue::OptionSome(Box::new(v))) .map(|v| SymbolValue::OptionSome(Box::new(v)))
} }
} else { } else {
@ -1149,7 +1164,8 @@ impl SymbolResolver for Resolver {
} }
} }
Ok(sym_value) Ok(sym_value)
}).unwrap() })
.unwrap()
} }
fn get_symbol_type( fn get_symbol_type(
@ -1166,7 +1182,7 @@ impl SymbolResolver for Resolver {
Ok(ty) Ok(ty)
} else { } else {
let Some(id) = self.0.name_to_pyid.get(&str) else { let Some(id) = self.0.name_to_pyid.get(&str) else {
return Err(format!("cannot find symbol `{str}`")) return Err(format!("cannot find symbol `{str}`"));
}; };
let result = if let Some(t) = { let result = if let Some(t) = {
let pyid_to_type = self.0.pyid_to_type.read(); let pyid_to_type = self.0.pyid_to_type.read();
@ -1191,7 +1207,8 @@ impl SymbolResolver for Resolver {
} }
} }
Ok(sym_ty) Ok(sym_ty)
}).unwrap() })
.unwrap()
}; };
result result
} }
@ -1242,15 +1259,16 @@ impl SymbolResolver for Resolver {
id_to_def.get(&id).copied().ok_or_else(String::new) id_to_def.get(&id).copied().ok_or_else(String::new)
} }
.or_else(|_| { .or_else(|_| {
let py_id = self.0.name_to_pyid.get(&id) let py_id = self
.ok_or_else(|| HashSet::from([ .0
format!("Undefined identifier `{id}`"), .name_to_pyid
]))?; .get(&id)
let result = self.0.pyid_to_def.read().get(py_id) .ok_or_else(|| HashSet::from([format!("Undefined identifier `{id}`")]))?;
.copied() let result = self.0.pyid_to_def.read().get(py_id).copied().ok_or_else(|| {
.ok_or_else(|| HashSet::from([ HashSet::from([format!(
format!("`{id}` is not registered with NAC3 (@nac3 decorator missing?)"), "`{id}` is not registered with NAC3 (@nac3 decorator missing?)"
]))?; )])
})?;
self.0.id_to_def.write().insert(id, result); self.0.id_to_def.write().insert(id, result);
Ok(result) Ok(result)
}) })
@ -1274,7 +1292,7 @@ impl SymbolResolver for Resolver {
&self, &self,
unifier: &mut Unifier, unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>], defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore primitives: &PrimitiveStore,
) -> Result<(), String> { ) -> Result<(), String> {
// we don't need a lock because this will only be run in a single thread // we don't need a lock because this will only be run in a single thread
if self.0.deferred_eval_store.needs_defer.load(Relaxed) { if self.0.deferred_eval_store.needs_defer.load(Relaxed) {
@ -1304,7 +1322,8 @@ impl SymbolResolver for Resolver {
} }
} }
Ok(Ok(())) Ok(Ok(()))
}).unwrap()?; })
.unwrap()?;
} }
Ok(()) Ok(())
} }

View File

@ -1,10 +1,12 @@
use inkwell::{values::{BasicValueEnum, CallSiteValue}, AddressSpace, AtomicOrdering}; use inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
};
use itertools::Either; use itertools::Either;
use nac3core::codegen::CodeGenContext; use nac3core::codegen::CodeGenContext;
/// Functions for manipulating the timeline. /// Functions for manipulating the timeline.
pub trait TimeFns { pub trait TimeFns {
/// Emits LLVM IR for `now_mu`. /// Emits LLVM IR for `now_mu`.
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>; fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
@ -27,26 +29,31 @@ impl TimeFns for NowPinningTimeFns64 {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}.unwrap(); }
.unwrap();
let now_hi = ctx.builder.build_load(now_hiptr, "now.hi") let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let now_lo = ctx.builder.build_load(now_loptr, "now.lo") let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap(); let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder let shifted_hi =
.build_left_shift(zext_hi, i64_type.const_int(32, false), "") ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
.unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap(); let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap() ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
} }
@ -58,7 +65,8 @@ impl TimeFns for NowPinningTimeFns64 {
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value(); let time = t.into_int_value();
let time_hi = ctx.builder let time_hi = ctx
.builder
.build_int_truncate( .build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(), ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type, i32_type,
@ -70,14 +78,16 @@ impl TimeFns for NowPinningTimeFns64 {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}.unwrap(); }
.unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap() .unwrap()
@ -90,50 +100,49 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap(); .unwrap();
} }
fn emit_delay_mu<'ctx>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let now = ctx let now = ctx
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}.unwrap(); }
.unwrap();
let now_hi = ctx.builder.build_load(now_hiptr, "now.hi") let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let now_lo = ctx.builder.build_load(now_loptr, "now.lo") let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let dt = dt.into_int_value(); let dt = dt.into_int_value();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap(); let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder let shifted_hi =
.build_left_shift(zext_hi, i64_type.const_int(32, false), "") ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
.unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap(); let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap(); let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap(); let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder let time_hi = ctx
.builder
.build_int_truncate( .build_int_truncate(
ctx.builder.build_right_shift( ctx.builder
time, .build_right_shift(time, i64_type.const_int(32, false), false, "")
i64_type.const_int(32, false), .unwrap(),
false,
"",
).unwrap(),
i32_type, i32_type,
"time.hi", "time.hi",
) )
@ -164,16 +173,16 @@ impl TimeFns for NowPinningTimeFns {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now") let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "now")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap(); let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap(); let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_mu") ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap()
.map(Into::into)
.unwrap()
} }
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) { fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -183,7 +192,8 @@ impl TimeFns for NowPinningTimeFns {
let time = t.into_int_value(); let time = t.into_int_value();
let time_hi = ctx.builder let time_hi = ctx
.builder
.build_int_truncate( .build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(), ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
i32_type, i32_type,
@ -195,14 +205,16 @@ impl TimeFns for NowPinningTimeFns {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}.unwrap(); }
.unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap() .unwrap()
@ -215,11 +227,7 @@ impl TimeFns for NowPinningTimeFns {
.unwrap(); .unwrap();
} }
fn emit_delay_mu<'ctx>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let i32_type = ctx.ctx.i32_type(); let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type(); let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false); let i64_32 = i64_type.const_int(32, false);
@ -227,7 +235,8 @@ impl TimeFns for NowPinningTimeFns {
.module .module
.get_global("now") .get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now")); .unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "") .build_load(now.as_pointer_value(), "")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
@ -238,7 +247,8 @@ impl TimeFns for NowPinningTimeFns {
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap(); let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap(); let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap(); let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder let time_hi = ctx
.builder
.build_int_truncate( .build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(), ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type, i32_type,
@ -246,14 +256,16 @@ impl TimeFns for NowPinningTimeFns {
) )
.unwrap(); .unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap(); let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_hiptr = ctx.builder let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr") .build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value) .map(BasicValueEnum::into_pointer_value)
.unwrap(); .unwrap();
let now_loptr = unsafe { let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr") ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}.unwrap(); }
.unwrap();
ctx.builder ctx.builder
.build_store(now_hiptr, time_hi) .build_store(now_hiptr, time_hi)
.unwrap() .unwrap()
@ -276,7 +288,8 @@ impl TimeFns for ExternTimeFns {
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| { let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None) ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
}); });
ctx.builder.build_call(now_mu, &[], "now_mu") ctx.builder
.build_call(now_mu, &[], "now_mu")
.map(CallSiteValue::try_as_basic_value) .map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
@ -293,11 +306,7 @@ impl TimeFns for ExternTimeFns {
ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap(); ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
} }
fn emit_delay_mu<'ctx>( fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| { let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
ctx.module.add_function( ctx.module.add_function(
"delay_mu", "delay_mu",

File diff suppressed because it is too large Load Diff

View File

@ -28,12 +28,12 @@ impl From<bool> for Constant {
} }
impl From<i32> for Constant { impl From<i32> for Constant {
fn from(i: i32) -> Constant { fn from(i: i32) -> Constant {
Self::Int(i as i128) Self::Int(i128::from(i))
} }
} }
impl From<i64> for Constant { impl From<i64> for Constant {
fn from(i: i64) -> Constant { fn from(i: i64) -> Constant {
Self::Int(i as i128) Self::Int(i128::from(i))
} }
} }
@ -50,6 +50,7 @@ pub enum ConversionFlag {
} }
impl ConversionFlag { impl ConversionFlag {
#[must_use]
pub fn try_from_byte(b: u8) -> Option<Self> { pub fn try_from_byte(b: u8) -> Option<Self> {
match b { match b {
b's' => Some(Self::Str), b's' => Some(Self::Str),
@ -69,6 +70,7 @@ pub struct ConstantOptimizer {
#[cfg(feature = "constant-optimization")] #[cfg(feature = "constant-optimization")]
impl ConstantOptimizer { impl ConstantOptimizer {
#[inline] #[inline]
#[must_use]
pub fn new() -> Self { pub fn new() -> Self {
Self { _priv: () } Self { _priv: () }
} }
@ -85,33 +87,22 @@ impl<U> crate::fold::Fold<U> for ConstantOptimizer {
fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> { fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> {
match node.node { match node.node {
crate::ExprKind::Tuple { elts, ctx } => { crate::ExprKind::Tuple { elts, ctx } => {
let elts = elts let elts =
.into_iter() elts.into_iter().map(|x| self.fold_expr(x)).collect::<Result<Vec<_>, _>>()?;
.map(|x| self.fold_expr(x)) let expr =
.collect::<Result<Vec<_>, _>>()?; if elts.iter().all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) {
let expr = if elts let tuple = elts
.iter() .into_iter()
.all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) .map(|e| match e.node {
{ crate::ExprKind::Constant { value, .. } => value,
let tuple = elts _ => unreachable!(),
.into_iter() })
.map(|e| match e.node { .collect();
crate::ExprKind::Constant { value, .. } => value, crate::ExprKind::Constant { value: Constant::Tuple(tuple), kind: None }
_ => unreachable!(), } else {
}) crate::ExprKind::Tuple { elts, ctx }
.collect(); };
crate::ExprKind::Constant { Ok(crate::Expr { node: expr, custom: node.custom, location: node.location })
value: Constant::Tuple(tuple),
kind: None,
}
} else {
crate::ExprKind::Tuple { elts, ctx }
};
Ok(crate::Expr {
node: expr,
custom: node.custom,
location: node.location,
})
} }
_ => crate::fold::fold_expr(self, node), _ => crate::fold::fold_expr(self, node),
} }
@ -138,18 +129,12 @@ mod tests {
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 1.into(), kind: None },
value: 1.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 2.into(), kind: None },
value: 2.into(),
kind: None,
},
}, },
Located { Located {
location, location,
@ -160,26 +145,17 @@ mod tests {
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 3.into(), kind: None },
value: 3.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 4.into(), kind: None },
value: 4.into(),
kind: None,
},
}, },
Located { Located {
location, location,
custom, custom,
node: ExprKind::Constant { node: ExprKind::Constant { value: 5.into(), kind: None },
value: 5.into(),
kind: None,
},
}, },
], ],
}, },
@ -187,9 +163,7 @@ mod tests {
], ],
}, },
}; };
let new_ast = ConstantOptimizer::new() let new_ast = ConstantOptimizer::new().fold_expr(ast).unwrap_or_else(|e| match e {});
.fold_expr(ast)
.unwrap_or_else(|e| match e {});
assert_eq!( assert_eq!(
new_ast, new_ast,
Located { Located {
@ -199,11 +173,7 @@ mod tests {
value: Constant::Tuple(vec![ value: Constant::Tuple(vec![
1.into(), 1.into(),
2.into(), 2.into(),
Constant::Tuple(vec![ Constant::Tuple(vec![3.into(), 4.into(), 5.into(),])
3.into(),
4.into(),
5.into(),
])
]), ]),
kind: None kind: None
}, },

View File

@ -64,11 +64,4 @@ macro_rules! simple_fold {
}; };
} }
simple_fold!( simple_fold!(usize, String, bool, StrRef, constant::Constant, constant::ConversionFlag);
usize,
String,
bool,
StrRef,
constant::Constant,
constant::ConversionFlag
);

View File

@ -2,6 +2,7 @@ use crate::{Constant, ExprKind};
impl<U> ExprKind<U> { impl<U> ExprKind<U> {
/// Returns a short name for the node suitable for use in error messages. /// Returns a short name for the node suitable for use in error messages.
#[must_use]
pub fn name(&self) -> &'static str { pub fn name(&self) -> &'static str {
match self { match self {
ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => { ExprKind::BoolOp { .. } | ExprKind::BinOp { .. } | ExprKind::UnaryOp { .. } => {
@ -34,10 +35,7 @@ impl<U> ExprKind<U> {
ExprKind::Starred { .. } => "starred", ExprKind::Starred { .. } => "starred",
ExprKind::Slice { .. } => "slice", ExprKind::Slice { .. } => "slice",
ExprKind::JoinedStr { values } => { ExprKind::JoinedStr { values } => {
if values if values.iter().any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) {
.iter()
.any(|e| matches!(e.node, ExprKind::JoinedStr { .. }))
{
"f-string expression" "f-string expression"
} else { } else {
"literal" "literal"

View File

@ -1,3 +1,19 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::too_many_lines,
clippy::wildcard_imports
)]
#[macro_use] #[macro_use]
extern crate lazy_static; extern crate lazy_static;
@ -9,6 +25,6 @@ mod impls;
mod location; mod location;
pub use ast_gen::*; pub use ast_gen::*;
pub use location::{Location, FileName}; pub use location::{FileName, Location};
pub type Suite<U = ()> = Vec<Stmt<U>>; pub type Suite<U = ()> = Vec<Stmt<U>>;

View File

@ -1,6 +1,6 @@
//! Datatypes to support source location information. //! Datatypes to support source location information.
use std::cmp::Ordering;
use crate::ast_gen::StrRef; use crate::ast_gen::StrRef;
use std::cmp::Ordering;
use std::fmt; use std::fmt;
#[derive(Clone, Copy, Debug, Eq, PartialEq)] #[derive(Clone, Copy, Debug, Eq, PartialEq)]
@ -22,7 +22,7 @@ impl From<String> for FileName {
pub struct Location { pub struct Location {
pub row: usize, pub row: usize,
pub column: usize, pub column: usize,
pub file: FileName pub file: FileName,
} }
impl fmt::Display for Location { impl fmt::Display for Location {
@ -35,12 +35,12 @@ impl Ord for Location {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string()); let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string());
if file_cmp != Ordering::Equal { if file_cmp != Ordering::Equal {
return file_cmp return file_cmp;
} }
let row_cmp = self.row.cmp(&other.row); let row_cmp = self.row.cmp(&other.row);
if row_cmp != Ordering::Equal { if row_cmp != Ordering::Equal {
return row_cmp return row_cmp;
} }
self.column.cmp(&other.column) self.column.cmp(&other.column)
@ -76,23 +76,22 @@ impl Location {
) )
} }
} }
Visualize { Visualize { loc: *self, line, desc }
loc: *self,
line,
desc,
}
} }
} }
impl Location { impl Location {
#[must_use]
pub fn new(row: usize, column: usize, file: FileName) -> Self { pub fn new(row: usize, column: usize, file: FileName) -> Self {
Location { row, column, file } Location { row, column, file }
} }
#[must_use]
pub fn row(&self) -> usize { pub fn row(&self) -> usize {
self.row self.row
} }
#[must_use]
pub fn column(&self) -> usize { pub fn column(&self) -> usize {
self.column self.column
} }

View File

@ -11,6 +11,8 @@ indexmap = "2.2"
parking_lot = "0.12" parking_lot = "0.12"
rayon = "1.8" rayon = "1.8"
nac3parser = { path = "../nac3parser" } nac3parser = { path = "../nac3parser" }
strum = "0.26.2"
strum_macros = "0.26.4"
[dependencies.inkwell] [dependencies.inkwell]
version = "0.4" version = "0.4"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -7,9 +7,9 @@ use crate::{
}, },
}; };
use indexmap::IndexMap;
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
use std::collections::HashMap; use std::collections::HashMap;
use indexmap::IndexMap;
pub struct ConcreteTypeStore { pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>, store: Vec<ConcreteTypeEnum>,
@ -202,9 +202,9 @@ impl ConcreteTypeStore {
TypeEnum::TFunc(signature) => { TypeEnum::TFunc(signature) => {
self.from_signature(unifier, primitives, signature, cache) self.from_signature(unifier, primitives, signature, cache)
} }
TypeEnum::TLiteral { values, .. } => ConcreteTypeEnum::TLiteral { TypeEnum::TLiteral { values, .. } => {
values: values.clone(), ConcreteTypeEnum::TLiteral { values: values.clone() }
}, }
_ => unreachable!("{:?}", ty_enum.get_type_name()), _ => unreachable!("{:?}", ty_enum.get_type_name()),
}; };
let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() { let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() {
@ -292,9 +292,8 @@ impl ConcreteTypeStore {
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache))) .map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
.collect::<VarMap>(), .collect::<VarMap>(),
}), }),
ConcreteTypeEnum::TLiteral { values, .. } => TypeEnum::TLiteral { ConcreteTypeEnum::TLiteral { values, .. } => {
values: values.clone(), TypeEnum::TLiteral { values: values.clone(), loc: None }
loc: None,
} }
}; };
let result = unifier.add_ty(result); let result = unifier.add_ty(result);

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,7 @@ pub fn call_tan<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -53,7 +53,7 @@ pub fn call_asin<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -85,7 +85,7 @@ pub fn call_acos<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -117,7 +117,7 @@ pub fn call_atan<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -149,7 +149,7 @@ pub fn call_sinh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -181,7 +181,7 @@ pub fn call_cosh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -213,7 +213,7 @@ pub fn call_tanh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -245,7 +245,7 @@ pub fn call_asinh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -277,7 +277,7 @@ pub fn call_acosh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -309,7 +309,7 @@ pub fn call_atanh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -341,7 +341,7 @@ pub fn call_expm1<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -373,7 +373,7 @@ pub fn call_cbrt<'ctx>(
for attr in ["mustprogress", "nofree", "nosync", "nounwind", "readonly", "willreturn"] { for attr in ["mustprogress", "nofree", "nosync", "nounwind", "readonly", "willreturn"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -404,7 +404,7 @@ pub fn call_erf<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None); let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
); );
func func
@ -434,7 +434,7 @@ pub fn call_erfc<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None); let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
); );
func func
@ -465,7 +465,7 @@ pub fn call_j1<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None); let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
); );
func func
@ -498,7 +498,7 @@ pub fn call_atan2<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -533,7 +533,7 @@ pub fn call_ldexp<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] { for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
); );
} }
@ -566,7 +566,7 @@ pub fn call_hypot<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None); let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
); );
func func
@ -598,7 +598,7 @@ pub fn call_nextafter<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None); let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute( func.add_attribute(
AttributeLoc::Function, AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0) ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
); );
func func
@ -610,4 +610,4 @@ pub fn call_nextafter<'ctx>(
.map(|v| v.map_left(BasicValueEnum::into_float_value)) .map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap() .unwrap()
} }

View File

@ -1,5 +1,5 @@
use crate::{ use crate::{
codegen::{classes::ArraySliceValue, expr::*, stmt::*, bool_to_i1, bool_to_i8, CodeGenContext}, codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type}, typecheck::typedef::{FunSignature, Type},
@ -210,7 +210,7 @@ pub trait CodeGenerator {
fn bool_to_i1<'ctx>( fn bool_to_i1<'ctx>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx> bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
bool_to_i1(&ctx.builder, bool_value) bool_to_i1(&ctx.builder, bool_value)
} }
@ -219,7 +219,7 @@ pub trait CodeGenerator {
fn bool_to_i8<'ctx>( fn bool_to_i8<'ctx>(
&self, &self,
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx> bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
bool_to_i8(&ctx.builder, ctx.ctx, bool_value) bool_to_i8(&ctx.builder, ctx.ctx, bool_value)
} }
@ -239,7 +239,6 @@ impl DefaultCodeGenerator {
} }
impl CodeGenerator for DefaultCodeGenerator { impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`]. /// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str { fn get_name(&self) -> &str {
&self.name &self.name

View File

@ -2,18 +2,13 @@ use crate::typecheck::typedef::Type;
use super::{ use super::{
classes::{ classes::{
ArrayLikeIndexer, ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
ArrayLikeValue, TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
ArraySliceValue,
ListValue,
NDArrayValue,
TypedArrayLikeAdapter,
UntypedArrayLikeAccessor,
}, },
CodeGenContext, llvm_intrinsics, CodeGenContext, CodeGenerator,
CodeGenerator,
llvm_intrinsics,
}; };
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
use inkwell::{ use inkwell::{
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
context::Context, context::Context,
@ -25,8 +20,6 @@ use inkwell::{
}; };
use itertools::Either; use itertools::Either;
use nac3parser::ast::Expr; use nac3parser::ast::Expr;
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
#[must_use] #[must_use]
pub fn load_irrt(ctx: &Context) -> Module { pub fn load_irrt(ctx: &Context) -> Module {
@ -70,12 +63,15 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
ctx.module.add_function(symbol, fn_type, None) ctx.module.add_function(symbol, fn_type, None)
}); });
// throw exception when exp < 0 // throw exception when exp < 0
let ge_zero = ctx.builder.build_int_compare( let ge_zero = ctx
IntPredicate::SGE, .builder
exp, .build_int_compare(
exp.get_type().const_zero(), IntPredicate::SGE,
"assert_int_pow_ge_0", exp,
).unwrap(); exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)
.unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
ge_zero, ge_zero,
@ -107,12 +103,10 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
}); });
// assert step != 0, throw exception if not // assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare( let not_zero = ctx
IntPredicate::NE, .builder
step, .build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
step.get_type().const_zero(), .unwrap();
"range_step_ne",
).unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
not_zero, not_zero,
@ -208,15 +202,18 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
let step = if let Some(v) = generator.gen_expr(ctx, step)? { let step = if let Some(v) = generator.gen_expr(ctx, step)? {
v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value() v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
} else { } else {
return Ok(None) return Ok(None);
}; };
// assert step != 0, throw exception if not // assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare( let not_zero = ctx
IntPredicate::NE, .builder
step, .build_int_compare(
step.get_type().const_zero(), IntPredicate::NE,
"range_step_ne", step,
).unwrap(); step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
not_zero, not_zero,
@ -226,25 +223,32 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
ctx.current_loc, ctx.current_loc,
); );
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap(); let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg").unwrap(); let neg = ctx
.builder
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
.unwrap();
( (
match s { match s {
Some(s) => { Some(s) => {
let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else { let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
return Ok(None) return Ok(None);
}; };
ctx.builder ctx.builder
.build_select( .build_select(
ctx.builder.build_and( ctx.builder
ctx.builder.build_int_compare( .build_and(
IntPredicate::EQ, ctx.builder
s, .build_int_compare(
length, IntPredicate::EQ,
"s_eq_len", s,
).unwrap(), length,
neg, "s_eq_len",
"should_minus_one", )
).unwrap(), .unwrap(),
neg,
"should_minus_one",
)
.unwrap(),
ctx.builder.build_int_sub(s, one, "s_min").unwrap(), ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
s, s,
"final_start", "final_start",
@ -252,14 +256,16 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap() .unwrap()
} }
None => ctx.builder.build_select(neg, len_id, zero, "stt") None => ctx
.builder
.build_select(neg, len_id, zero, "stt")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(), .unwrap(),
}, },
match e { match e {
Some(e) => { Some(e) => {
let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else { let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
return Ok(None) return Ok(None);
}; };
ctx.builder ctx.builder
.build_select( .build_select(
@ -271,7 +277,9 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap() .unwrap()
} }
None => ctx.builder.build_select(neg, zero, len_id, "end") None => ctx
.builder
.build_select(neg, zero, len_id, "end")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(), .unwrap(),
}, },
@ -299,15 +307,16 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
let i = if let Some(v) = generator.gen_expr(ctx, i)? { let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())? v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else { } else {
return Ok(None) return Ok(None);
}; };
Ok(Some(ctx Ok(Some(
.builder ctx.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind") .build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value) .map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left) .map(Either::unwrap_left)
.unwrap())) .unwrap(),
))
} }
/// This function handles 'end' **inclusively**. /// This function handles 'end' **inclusively**.
@ -349,47 +358,33 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
let zero = int32.const_zero(); let zero = int32.const_zero();
let one = int32.const_int(1, false); let one = int32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator); let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr = ctx.builder.build_pointer_cast( let dest_arr_ptr =
dest_arr_ptr, ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
elem_ptr_type,
"dest_arr_ptr_cast",
).unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len")); let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap(); let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator); let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let src_arr_ptr = ctx.builder.build_pointer_cast( let src_arr_ptr =
src_arr_ptr, ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
elem_ptr_type,
"src_arr_ptr_cast",
).unwrap();
let src_len = src_arr.load_size(ctx, Some("src.len")); let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap(); let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
// index in bound and positive should be done // index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and // assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied // throw exception if not satisfied
let src_end = ctx.builder let src_end = ctx
.builder
.build_select( .build_select(
ctx.builder.build_int_compare( ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
IntPredicate::SLT,
src_idx.2,
zero,
"is_neg",
).unwrap(),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(), ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(), ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
"final_e", "final_e",
) )
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let dest_end = ctx.builder let dest_end = ctx
.builder
.build_select( .build_select(
ctx.builder.build_int_compare( ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
IntPredicate::SLT,
dest_idx.2,
zero,
"is_neg",
).unwrap(),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(), ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(), ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
"final_e", "final_e",
@ -400,24 +395,23 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2); calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len = let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2); calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx.builder.build_int_compare( let src_eq_dest = ctx
IntPredicate::EQ, .builder
src_slice_len, .build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
dest_slice_len, .unwrap();
"slice_src_eq_dest", let src_slt_dest = ctx
).unwrap(); .builder
let src_slt_dest = ctx.builder.build_int_compare( .build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
IntPredicate::SLT, .unwrap();
src_slice_len, let dest_step_eq_one = ctx
dest_slice_len, .builder
"slice_src_slt_dest", .build_int_compare(
).unwrap(); IntPredicate::EQ,
let dest_step_eq_one = ctx.builder.build_int_compare( dest_idx.2,
IntPredicate::EQ, dest_idx.2.get_type().const_int(1, false),
dest_idx.2, "slice_dest_step_eq_one",
dest_idx.2.get_type().const_int(1, false), )
"slice_dest_step_eq_one", .unwrap();
).unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap(); let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap(); let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert( ctx.make_assert(
@ -461,17 +455,14 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
.unwrap() .unwrap()
}; };
// update length // update length
let need_update = ctx.builder let need_update =
.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update") ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
.unwrap();
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update"); let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont"); let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap(); ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb); ctx.builder.position_at_end(update_bb);
let new_len = ctx.builder let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len")
.unwrap();
dest_arr.store_size(ctx, generator, new_len); dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap(); ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb); ctx.builder.position_at_end(cont_bb);
@ -488,7 +479,8 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
ctx.module.add_function("__nac3_isinf", fn_type, None) ctx.module.add_function("__nac3_isinf", fn_type, None)
}); });
let ret = ctx.builder let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isinf") .build_call(intrinsic_fn, &[v.into()], "isinf")
.map(CallSiteValue::try_as_basic_value) .map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(|v| v.map_left(BasicValueEnum::into_int_value))
@ -509,7 +501,8 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
ctx.module.add_function("__nac3_isnan", fn_type, None) ctx.module.add_function("__nac3_isnan", fn_type, None)
}); });
let ret = ctx.builder let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isnan") .build_call(intrinsic_fn, &[v.into()], "isnan")
.map(CallSiteValue::try_as_basic_value) .map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value)) .map(|v| v.map_left(BasicValueEnum::into_int_value))
@ -520,10 +513,7 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
} }
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result. /// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>( pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| { let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
@ -540,10 +530,7 @@ pub fn call_gamma<'ctx>(
} }
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result. /// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>( pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| { let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
@ -560,10 +547,7 @@ pub fn call_gammaln<'ctx>(
} }
/// Generates a call to `j0` in IR. Returns an `f64` representing the result. /// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>( pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type(); let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| { let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
@ -583,7 +567,7 @@ pub fn call_j0<'ctx>(
/// calculated total size. /// calculated total size.
/// ///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension. /// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for, /// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively. /// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>( pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G, generator: &G,
@ -591,9 +575,10 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
dims: &Dims, dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>), (begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx> ) -> IntValue<'ctx>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>, { Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_i64 = ctx.ctx.i64_type(); let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
@ -602,19 +587,14 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() { let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size", 32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64", 64 => "__nac3_ndarray_calc_size64",
bw => unreachable!("Unsupported size type bit width: {}", bw) bw => unreachable!("Unsupported size type bit width: {}", bw),
}; };
let ndarray_calc_size_fn_t = llvm_usize.fn_type( let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[ &[llvm_pi64.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
llvm_pi64.into(),
llvm_usize.into(),
llvm_usize.into(),
llvm_usize.into(),
],
false, false,
); );
let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name) let ndarray_calc_size_fn =
.unwrap_or_else(|| { ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
}); });
@ -658,30 +638,22 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() { let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices", 32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64", 64 => "__nac3_ndarray_calc_nd_indices64",
bw => unreachable!("Unsupported size type bit width: {}", bw) bw => unreachable!("Unsupported size type bit width: {}", bw),
}; };
let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| { let ndarray_calc_nd_indices_fn =
let fn_type = llvm_void.fn_type( ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
&[ let fn_type = llvm_void.fn_type(
llvm_usize.into(), &[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
llvm_pusize.into(), false,
llvm_usize.into(), );
llvm_pi32.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None) ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
}); });
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes(); let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca( let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
llvm_i32,
ndarray_num_dims,
"",
).unwrap();
ctx.builder ctx.builder
.build_call( .build_call(
@ -709,9 +681,10 @@ fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
indices: &Indices, indices: &Indices,
) -> IntValue<'ctx> ) -> IntValue<'ctx>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>, { Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type(); let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx); let llvm_usize = generator.get_size_type(ctx.ctx);
@ -734,26 +707,23 @@ fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() { let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index", 32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64", 64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw) bw => unreachable!("Unsupported size type bit width: {}", bw),
}; };
let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| { let ndarray_flatten_index_fn =
let fn_type = llvm_usize.fn_type( ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
&[ let fn_type = llvm_usize.fn_type(
llvm_pusize.into(), &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
llvm_usize.into(), false,
llvm_pi32.into(), );
llvm_usize.into(),
],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None) ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
}); });
let ndarray_num_dims = ndarray.load_ndims(ctx); let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes(); let ndarray_dims = ndarray.dim_sizes();
let index = ctx.builder let index = ctx
.builder
.build_call( .build_call(
ndarray_flatten_index_fn, ndarray_flatten_index_fn,
&[ &[
@ -784,16 +754,11 @@ pub fn call_ndarray_flatten_index<'ctx, G, Index>(
ndarray: NDArrayValue<'ctx>, ndarray: NDArrayValue<'ctx>,
indices: &Index, indices: &Index,
) -> IntValue<'ctx> ) -> IntValue<'ctx>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>, { Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl( call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
generator,
ctx,
ndarray,
indices,
)
} }
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of /// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
@ -810,22 +775,23 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast", 32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64", 64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw) bw => unreachable!("Unsupported size type bit width: {}", bw),
}; };
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { let ndarray_calc_broadcast_fn =
let fn_type = llvm_usize.fn_type( ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
&[ let fn_type = llvm_usize.fn_type(
llvm_pusize.into(), &[
llvm_usize.into(), llvm_pusize.into(),
llvm_pusize.into(), llvm_usize.into(),
llvm_usize.into(), llvm_pusize.into(),
llvm_pusize.into(), llvm_usize.into(),
], llvm_pusize.into(),
false, ],
); false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
}); });
let lhs_ndims = lhs.load_ndims(ctx); let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx); let rhs_ndims = rhs.load_ndims(ctx);
@ -846,36 +812,22 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
}; };
let llvm_usize_const_one = llvm_usize.const_int(1, false); let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx.builder.build_int_compare( let lhs_eqz = ctx
IntPredicate::EQ, .builder
lhs_dim_sz, .build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
llvm_usize_const_one, .unwrap();
"", let rhs_eqz = ctx
).unwrap(); .builder
let rhs_eqz = ctx.builder.build_int_compare( .build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
IntPredicate::EQ, .unwrap();
rhs_dim_sz, let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
llvm_usize_const_one,
"",
).unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(
lhs_eqz,
rhs_eqz,
""
).unwrap();
let lhs_eq_rhs = ctx.builder.build_int_compare( let lhs_eq_rhs = ctx
IntPredicate::EQ, .builder
lhs_dim_sz, .build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
rhs_dim_sz, .unwrap();
""
).unwrap();
let is_compatible = ctx.builder.build_or( let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
lhs_or_rhs_eqz,
lhs_eq_rhs,
""
).unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
@ -889,7 +841,8 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
Ok(()) Ok(())
}, },
llvm_usize.const_int(1, false), llvm_usize.const_int(1, false),
).unwrap(); )
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None); let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator); let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
@ -923,7 +876,11 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`] /// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted /// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`. /// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>( pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G, generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>, ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>, array: NDArrayValue<'ctx>,
@ -937,21 +894,17 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() { let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx", 32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64", 64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw) bw => unreachable!("Unsupported size type bit width: {}", bw),
}; };
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| { let ndarray_calc_broadcast_fn =
let fn_type = llvm_usize.fn_type( ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
&[ let fn_type = llvm_usize.fn_type(
llvm_pusize.into(), &[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
llvm_usize.into(), false,
llvm_pi32.into(), );
llvm_pi32.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None) ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
}); });
let broadcast_size = broadcast_idx.size(ctx, generator); let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap(); let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
@ -959,23 +912,13 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
let array_dims = array.dim_sizes().base_ptr(ctx, generator); let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx); let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe { let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked( broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
ctx,
generator,
&llvm_usize.const_zero(),
None
)
}; };
ctx.builder ctx.builder
.build_call( .build_call(
ndarray_calc_broadcast_fn, ndarray_calc_broadcast_fn,
&[ &[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
array_dims.into(),
array_ndims.into(),
broadcast_idx_ptr.into(),
out_idx.into(),
],
"", "",
) )
.unwrap(); .unwrap();
@ -985,4 +928,4 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
Box::new(|_, v| v.into_int_value()), Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()), Box::new(|_, v| v.into()),
) )
} }

View File

@ -1,35 +1,35 @@
use inkwell::AddressSpace; use crate::codegen::CodeGenContext;
use inkwell::context::Context; use inkwell::context::Context;
use inkwell::intrinsics::Intrinsic; use inkwell::intrinsics::Intrinsic;
use inkwell::types::AnyTypeEnum::IntType; use inkwell::types::AnyTypeEnum::IntType;
use inkwell::types::FloatType; use inkwell::types::FloatType;
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue}; use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use inkwell::AddressSpace;
use itertools::Either; use itertools::Either;
use crate::codegen::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic /// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions. /// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str { fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
// Standard LLVM floating-point types // Standard LLVM floating-point types
if ft == ctx.f16_type() { if ft == ctx.f16_type() {
return "f16" return "f16";
} }
if ft == ctx.f32_type() { if ft == ctx.f32_type() {
return "f32" return "f32";
} }
if ft == ctx.f64_type() { if ft == ctx.f64_type() {
return "f64" return "f64";
} }
if ft == ctx.f128_type() { if ft == ctx.f128_type() {
return "f128" return "f128";
} }
// Non-standard floating-point types // Non-standard floating-point types
if ft == ctx.x86_f80_type() { if ft == ctx.x86_f80_type() {
return "f80" return "f80";
} }
if ft == ctx.ppc_f128_type() { if ft == ctx.ppc_f128_type() {
return "ppcf128" return "ppcf128";
} }
unreachable!() unreachable!()
@ -69,9 +69,7 @@ pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()])) .and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
.unwrap(); .unwrap();
ctx.builder ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap();
.build_call(intrinsic_fn, &[ptr.into()], "")
.unwrap();
} }
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic. /// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
@ -232,10 +230,12 @@ pub fn call_memcpy<'ctx>(
let llvm_len_t = len.get_type(); let llvm_len_t = len.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME) let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration( .and_then(|intrinsic| {
&ctx.module, intrinsic.get_declaration(
&[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()], &ctx.module,
)) &[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()],
)
})
.unwrap(); .unwrap();
ctx.builder ctx.builder
@ -315,10 +315,9 @@ pub fn call_float_powi<'ctx>(
let llvm_power_t = power.get_type(); let llvm_power_t = power.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME) let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration( .and_then(|intrinsic| {
&ctx.module, intrinsic.get_declaration(&ctx.module, &[llvm_val_t.into(), llvm_power_t.into()])
&[llvm_val_t.into(), llvm_power_t.into()], })
))
.unwrap(); .unwrap();
ctx.builder ctx.builder
@ -442,7 +441,6 @@ pub fn call_float_exp2<'ctx>(
.unwrap() .unwrap()
} }
/// Invokes the [`llvm.log`](https://llvm.org/docs/LangRef.html#llvm-log-intrinsic) intrinsic. /// Invokes the [`llvm.log`](https://llvm.org/docs/LangRef.html#llvm-log-intrinsic) intrinsic.
pub fn call_float_log<'ctx>( pub fn call_float_log<'ctx>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,
@ -672,7 +670,7 @@ pub fn call_float_round<'ctx>(
.unwrap() .unwrap()
} }
/// Invokes the /// Invokes the
/// [`llvm.roundeven`](https://llvm.org/docs/LangRef.html#llvm-roundeven-intrinsic) intrinsic. /// [`llvm.roundeven`](https://llvm.org/docs/LangRef.html#llvm-roundeven-intrinsic) intrinsic.
pub fn call_float_roundeven<'ctx>( pub fn call_float_roundeven<'ctx>(
ctx: &CodeGenContext<'ctx, '_>, ctx: &CodeGenContext<'ctx, '_>,

View File

@ -1,12 +1,7 @@
use crate::{ use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType}, codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver}, symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{ toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_var_tys,
TopLevelContext,
TopLevelDef,
},
typecheck::{ typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore}, type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier}, typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -14,24 +9,22 @@ use crate::{
}; };
use crossbeam::channel::{unbounded, Receiver, Sender}; use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{ use inkwell::{
AddressSpace,
IntPredicate,
OptimizationLevel,
attributes::{Attribute, AttributeLoc}, attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock, basic_block::BasicBlock,
builder::Builder, builder::Builder,
context::Context, context::Context,
debug_info::{
AsDIScope, DICompileUnit, DIFlagsConstants, DIScope, DISubprogram, DebugInfoBuilder,
},
module::Module, module::Module,
passes::PassBuilderOptions, passes::PassBuilderOptions,
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple}, targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
types::{AnyType, BasicType, BasicTypeEnum}, types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue}, values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
debug_info::{ AddressSpace, IntPredicate, OptimizationLevel,
DebugInfoBuilder, DICompileUnit, DISubprogram, AsDIScope, DIFlagsConstants, DIScope
},
}; };
use itertools::Itertools; use itertools::Itertools;
use nac3parser::ast::{Stmt, StrRef, Location}; use nac3parser::ast::{Location, Stmt, StrRef};
use parking_lot::{Condvar, Mutex}; use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::{ use std::sync::{
@ -91,7 +84,6 @@ pub struct CodeGenTargetMachineOptions {
} }
impl CodeGenTargetMachineOptions { impl CodeGenTargetMachineOptions {
/// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine. /// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine.
/// Other options are set to defaults. /// Other options are set to defaults.
#[must_use] #[must_use]
@ -120,13 +112,11 @@ impl CodeGenTargetMachineOptions {
/// ///
/// See [`Target::create_target_machine`]. /// See [`Target::create_target_machine`].
#[must_use] #[must_use]
pub fn create_target_machine( pub fn create_target_machine(&self, level: OptimizationLevel) -> Option<TargetMachine> {
&self,
level: OptimizationLevel,
) -> Option<TargetMachine> {
let triple = TargetTriple::create(self.triple.as_str()); let triple = TargetTriple::create(self.triple.as_str());
let target = Target::from_triple(&triple) let target = Target::from_triple(&triple).unwrap_or_else(|_| {
.unwrap_or_else(|_| panic!("could not create target from target triple {}", self.triple)); panic!("could not create target from target triple {}", self.triple)
});
target.create_target_machine( target.create_target_machine(
&triple, &triple,
@ -134,7 +124,7 @@ impl CodeGenTargetMachineOptions {
self.features.as_str(), self.features.as_str(),
level, level,
self.reloc_mode, self.reloc_mode,
self.code_model self.code_model,
) )
} }
} }
@ -205,7 +195,6 @@ pub struct CodeGenContext<'ctx, 'a> {
} }
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder` /// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
/// contains a [terminator statement][BasicBlock::get_terminator]. /// contains a [terminator statement][BasicBlock::get_terminator].
pub fn is_terminated(&self) -> bool { pub fn is_terminated(&self) -> bool {
@ -251,7 +240,6 @@ pub struct WorkerRegistry {
} }
impl WorkerRegistry { impl WorkerRegistry {
/// Creates workers for this registry. /// Creates workers for this registry.
#[must_use] #[must_use]
pub fn create_workers<G: CodeGenerator + Send + 'static>( pub fn create_workers<G: CodeGenerator + Send + 'static>(
@ -373,7 +361,11 @@ impl WorkerRegistry {
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
self.wait_condvar.notify_all(); self.wait_condvar.notify_all();
} }
assert!(errors.is_empty(), "Codegen error: {}", errors.into_iter().sorted().join("\n----------\n")); assert!(
errors.is_empty(),
"Codegen error: {}",
errors.into_iter().sorted().join("\n----------\n")
);
let result = module.verify(); let result = module.verify();
if let Err(err) = result { if let Err(err) = result {
@ -386,13 +378,20 @@ impl WorkerRegistry {
.llvm_options .llvm_options
.target .target
.create_target_machine(self.llvm_options.opt_level) .create_target_machine(self.llvm_options.opt_level)
.unwrap_or_else(|| panic!("could not create target machine from properties {:?}", self.llvm_options.target)); .unwrap_or_else(|| {
panic!(
"could not create target machine from properties {:?}",
self.llvm_options.target
)
});
let passes = format!("default<O{}>", self.llvm_options.opt_level as u32); let passes = format!("default<O{}>", self.llvm_options.opt_level as u32);
let result = module.run_passes(passes.as_str(), &target_machine, pass_options); let result = module.run_passes(passes.as_str(), &target_machine, pass_options);
if let Err(err) = result { if let Err(err) = result {
panic!("Failed to run optimization for module `{}`: {}", panic!(
module.get_name().to_str().unwrap(), "Failed to run optimization for module `{}`: {}",
err.to_string()); module.get_name().to_str().unwrap(),
err.to_string()
);
} }
f.run(&module); f.run(&module);
@ -436,9 +435,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let result = match &*ty_enum { let result = match &*ty_enum {
TObj { obj_id, fields, .. } => { TObj { obj_id, fields, .. } => {
// check to avoid treating non-class primitives as classes // check to avoid treating non-class primitives as classes
if obj_id.0 <= PRIMITIVE_DEF_IDS.max_id().0 { if PrimDef::contains_id(*obj_id) {
return match &*unifier.get_ty_immutable(ty) { return match &*unifier.get_ty_immutable(ty) {
TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.option => { TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => {
get_llvm_type( get_llvm_type(
ctx, ctx,
module, module,
@ -452,23 +451,20 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
.into() .into()
} }
TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let element_type = get_llvm_type( let element_type = get_llvm_type(
ctx, ctx, module, generator, unifier, top_level, type_cache, dtype,
module,
generator,
unifier,
top_level,
type_cache,
dtype,
); );
NDArrayType::new(generator, ctx, element_type).as_base_type().into() NDArrayType::new(generator, ctx, element_type).as_base_type().into()
} }
_ => unreachable!("LLVM type for primitive {} is missing", unifier.stringify(ty)), _ => unreachable!(
} "LLVM type for primitive {} is missing",
unifier.stringify(ty)
),
};
} }
// a struct with fields in the order of declaration // a struct with fields in the order of declaration
let top_level_defs = top_level.definitions.read(); let top_level_defs = top_level.definitions.read();
@ -484,7 +480,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let struct_type = ctx.opaque_struct_type(&name); let struct_type = ctx.opaque_struct_type(&name);
type_cache.insert( type_cache.insert(
unifier.get_representative(ty), unifier.get_representative(ty),
struct_type.ptr_type(AddressSpace::default()).into() struct_type.ptr_type(AddressSpace::default()).into(),
); );
let fields = fields_list let fields = fields_list
.iter() .iter()
@ -503,24 +499,21 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
struct_type.set_body(&fields, false); struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into() struct_type.ptr_type(AddressSpace::default()).into()
}; };
return ty return ty;
} }
TTuple { ty } => { TTuple { ty } => {
// a struct with fields in the order present in the tuple // a struct with fields in the order present in the tuple
let fields = ty let fields = ty
.iter() .iter()
.map(|ty| { .map(|ty| {
get_llvm_type( get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
ctx, module, generator, unifier, top_level, type_cache, *ty,
)
}) })
.collect_vec(); .collect_vec();
ctx.struct_type(&fields, false).into() ctx.struct_type(&fields, false).into()
} }
TList { ty } => { TList { ty } => {
let element_type = get_llvm_type( let element_type =
ctx, module, generator, unifier, top_level, type_cache, *ty, get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty);
);
ListType::new(generator, ctx, element_type).as_base_type().into() ListType::new(generator, ctx, element_type).as_base_type().into()
} }
@ -558,7 +551,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
ctx.bool_type().into() ctx.bool_type().into()
} else { } else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty) get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
} };
} }
/// Whether `sret` is needed for a return value with type `ty`. /// Whether `sret` is needed for a return value with type `ty`.
@ -574,8 +567,9 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
match ty { match ty {
BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false, BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false,
BasicTypeEnum::FloatType(_) if maybe_large => false, BasicTypeEnum::FloatType(_) if maybe_large => false,
BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => {
ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false)), ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false))
}
_ => true, _ => true,
} }
} }
@ -583,14 +577,18 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
} }
/// Implementation for generating LLVM IR for a function. /// Implementation for generating LLVM IR for a function.
pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>> ( pub fn gen_func_impl<
'ctx,
G: CodeGenerator,
F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>,
>(
context: &'ctx Context, context: &'ctx Context,
generator: &mut G, generator: &mut G,
registry: &WorkerRegistry, registry: &WorkerRegistry,
builder: Builder<'ctx>, builder: Builder<'ctx>,
module: Module<'ctx>, module: Module<'ctx>,
task: CodeGenTask, task: CodeGenTask,
codegen_function: F codegen_function: F,
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> { ) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
let top_level_ctx = registry.top_level_ctx.clone(); let top_level_ctx = registry.top_level_ctx.clone();
let static_value_store = registry.static_value_store.clone(); let static_value_store = registry.static_value_store.clone();
@ -654,7 +652,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
str_type.set_body(&fields, false); str_type.set_body(&fields, false);
str_type.into() str_type.into()
} }
Some(t) => t.as_basic_type_enum() Some(t) => t.as_basic_type_enum(),
} }
}), }),
(primitives.range, RangeType::new(context).as_base_type().into()), (primitives.range, RangeType::new(context).as_base_type().into()),
@ -671,7 +669,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
exception.set_body(&fields, false); exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum() exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
} }
}) }),
] ]
.iter() .iter()
.copied() .copied()
@ -679,8 +677,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
// NOTE: special handling of option cannot use this type cache since it contains type var, // NOTE: special handling of option cannot use this type cache since it contains type var,
// handled inside get_llvm_type instead // handled inside get_llvm_type instead
let ConcreteTypeEnum::TFunc { args, ret, .. } = let ConcreteTypeEnum::TFunc { args, ret, .. } = task.store.get(task.signature) else {
task.store.get(task.signature) else {
unreachable!() unreachable!()
}; };
@ -697,7 +694,16 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let ret_type = if unifier.unioned(ret, primitives.none) { let ret_type = if unifier.unioned(ret, primitives.none) {
None None
} else { } else {
Some(get_llvm_abi_type(context, &module, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, &primitives, ret)) Some(get_llvm_abi_type(
context,
&module,
generator,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
&primitives,
ret,
))
}; };
let has_sret = ret_type.map_or(false, |ty| need_sret(ty)); let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
@ -724,7 +730,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let fn_type = match ret_type { let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false), Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => context.void_type().fn_type(&params, false) _ => context.void_type().fn_type(&params, false),
}; };
let symbol = &task.symbol_name; let symbol = &task.symbol_name;
@ -739,9 +745,13 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
fn_val.set_personality_function(personality); fn_val.set_personality_function(personality);
} }
if has_sret { if has_sret {
fn_val.add_attribute(AttributeLoc::Param(0), fn_val.add_attribute(
context.create_type_attribute(Attribute::get_named_enum_kind_id("sret"), AttributeLoc::Param(0),
ret_type.unwrap().as_any_type_enum())); context.create_type_attribute(
Attribute::get_named_enum_kind_id("sret"),
ret_type.unwrap().as_any_type_enum(),
),
);
} }
let init_bb = context.append_basic_block(fn_val, "init"); let init_bb = context.append_basic_block(fn_val, "init");
@ -761,9 +771,8 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
&mut type_cache, &mut type_cache,
arg.ty, arg.ty,
); );
let alloca = builder let alloca =
.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string())) builder.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string())).unwrap();
.unwrap();
// Remap boolean parameters into i8 // Remap boolean parameters into i8
let param = if local_type.is_int_type() && param.is_int_value() { let param = if local_type.is_int_type() && param.is_int_value() {
@ -774,7 +783,8 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
bool_to_i8(&builder, context, param_val) bool_to_i8(&builder, context, param_val)
} else { } else {
param_val param_val
}.into() }
.into()
} else { } else {
param param
}; };
@ -808,10 +818,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
&task &task
.body .body
.first() .first()
.map_or_else( .map_or_else(|| "<nac3_internal>".to_string(), |f| f.location.file.0.to_string()),
|| "<nac3_internal>".to_string(),
|f| f.location.file.0.to_string(),
),
/* directory */ "", /* directory */ "",
/* producer */ "NAC3", /* producer */ "NAC3",
/* is_optimized */ registry.llvm_options.opt_level != OptimizationLevel::None, /* is_optimized */ registry.llvm_options.opt_level != OptimizationLevel::None,
@ -884,10 +891,10 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
row as u32, row as u32,
col as u32, col as u32,
func_scope.as_debug_info_scope(), func_scope.as_debug_info_scope(),
None None,
); );
code_gen_context.builder.set_current_debug_location(loc); code_gen_context.builder.set_current_debug_location(loc);
let result = codegen_function(generator, &mut code_gen_context); let result = codegen_function(generator, &mut code_gen_context);
// after static analysis, only void functions can have no return at the end. // after static analysis, only void functions can have no return at the end.
@ -949,7 +956,7 @@ fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntV
fn bool_to_i8<'ctx>( fn bool_to_i8<'ctx>(
builder: &Builder<'ctx>, builder: &Builder<'ctx>,
ctx: &'ctx Context, ctx: &'ctx Context,
bool_value: IntValue<'ctx> bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let value_bits = bool_value.get_type().get_bit_width(); let value_bits = bool_value.get_type().get_bit_width();
match value_bits { match value_bits {
@ -965,7 +972,7 @@ fn bool_to_i8<'ctx>(
bool_value.get_type().const_zero(), bool_value.get_type().const_zero(),
"", "",
) )
.unwrap() .unwrap(),
), ),
} }
} }
@ -991,11 +998,18 @@ fn gen_in_range_check<'ctx>(
stop: IntValue<'ctx>, stop: IntValue<'ctx>,
step: IntValue<'ctx>, step: IntValue<'ctx>,
) -> IntValue<'ctx> { ) -> IntValue<'ctx> {
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "").unwrap(); let sign = ctx
let lo = ctx.builder.build_select(sign, value, stop, "") .builder
.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "")
.unwrap();
let lo = ctx
.builder
.build_select(sign, value, stop, "")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let hi = ctx.builder.build_select(sign, stop, value, "") let hi = ctx
.builder
.build_select(sign, stop, value, "")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();

File diff suppressed because it is too large Load Diff

View File

@ -10,12 +10,7 @@ use crate::{
expr::gen_binop_expr, expr::gen_binop_expr,
gen_in_range_check, gen_in_range_check,
}, },
toplevel::{ toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_var_tys,
TopLevelDef,
},
typecheck::typedef::{FunSignature, Type, TypeEnum}, typecheck::typedef::{FunSignature, Type, TypeEnum},
}; };
use inkwell::{ use inkwell::{
@ -116,13 +111,13 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
ctx.var_assignment.insert(*id, (ptr, None, counter)); ctx.var_assignment.insert(*id, (ptr, None, counter));
ptr ptr
} }
} },
ExprKind::Attribute { value, attr, .. } => { ExprKind::Attribute { value, attr, .. } => {
let index = ctx.get_attr_index(value.custom.unwrap(), *attr); let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
let val = if let Some(v) = generator.gen_expr(ctx, value)? { let val = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())? v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
} else { } else {
return Ok(None) return Ok(None);
}; };
let BasicValueEnum::PointerValue(ptr) = val else { let BasicValueEnum::PointerValue(ptr) = val else {
unreachable!(); unreachable!();
@ -136,7 +131,8 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
], ],
name.unwrap_or(""), name.unwrap_or(""),
) )
}.unwrap() }
.unwrap()
} }
ExprKind::Subscript { value, slice, .. } => { ExprKind::Subscript { value, slice, .. } => {
match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() { match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() {
@ -153,11 +149,13 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
.unwrap() .unwrap()
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())? .to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value(); .into_int_value();
let raw_index = ctx.builder let raw_index = ctx
.builder
.build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext") .build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext")
.unwrap(); .unwrap();
// handle negative index // handle negative index
let is_negative = ctx.builder let is_negative = ctx
.builder
.build_int_compare( .build_int_compare(
IntPredicate::SLT, IntPredicate::SLT,
raw_index, raw_index,
@ -173,13 +171,9 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
.unwrap(); .unwrap();
// unsigned less than is enough, because negative index after adjustment is // unsigned less than is enough, because negative index after adjustment is
// bigger than the length (for unsigned cmp) // bigger than the length (for unsigned cmp)
let bound_check = ctx.builder let bound_check = ctx
.build_int_compare( .builder
IntPredicate::ULT, .build_int_compare(IntPredicate::ULT, index, len, "inbound")
index,
len,
"inbound",
)
.unwrap(); .unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
@ -192,7 +186,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
v.data().ptr_offset(ctx, generator, &index, name) v.data().ptr_offset(ctx, generator, &index, name)
} }
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
todo!() todo!()
} }
@ -215,7 +209,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
match &target.node { match &target.node {
ExprKind::Tuple { elts, .. } => { ExprKind::Tuple { elts, .. } => {
let BasicValueEnum::StructValue(v) = let BasicValueEnum::StructValue(v) =
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? else { value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
else {
unreachable!() unreachable!()
}; };
@ -230,9 +225,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
ExprKind::Subscript { value: ls, slice, .. } ExprKind::Subscript { value: ls, slice, .. }
if matches!(&slice.node, ExprKind::Slice { .. }) => if matches!(&slice.node, ExprKind::Slice { .. }) =>
{ {
let ExprKind::Slice { lower, upper, step } = &slice.node else { let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() };
unreachable!()
};
let ls = generator let ls = generator
.gen_expr(ctx, ls)? .gen_expr(ctx, ls)?
@ -240,21 +233,18 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())? .to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
.into_pointer_value(); .into_pointer_value();
let ls = ListValue::from_ptr_val(ls, llvm_usize, None); let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
let Some((start, end, step)) = handle_slice_indices( let Some((start, end, step)) =
lower, handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))?
upper, else {
step, return Ok(());
ctx, };
generator,
ls.load_size(ctx, None),
)? else { return Ok(()) };
let value = value let value = value
.to_basic_value_enum(ctx, generator, target.custom.unwrap())? .to_basic_value_enum(ctx, generator, target.custom.unwrap())?
.into_pointer_value(); .into_pointer_value();
let value = ListValue::from_ptr_val(value, llvm_usize, None); let value = ListValue::from_ptr_val(value, llvm_usize, None);
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty, TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
} }
_ => unreachable!(), _ => unreachable!(),
@ -268,7 +258,10 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
ctx, ctx,
generator, generator,
value.load_size(ctx, None), value.load_size(ctx, None),
)? else { return Ok(()) }; )?
else {
return Ok(());
};
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind); list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
} }
_ => { _ => {
@ -278,7 +271,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
String::from("target.addr") String::from("target.addr")
}; };
let Some(ptr) = generator.gen_store_target(ctx, target, Some(name.as_str()))? else { let Some(ptr) = generator.gen_store_target(ctx, target, Some(name.as_str()))? else {
return Ok(()) return Ok(());
}; };
if let ExprKind::Name { id, .. } = &target.node { if let ExprKind::Name { id, .. } = &target.node {
@ -301,9 +294,7 @@ pub fn gen_for<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { unreachable!() };
unreachable!()
};
// var_assignment static values may be changed in another branch // var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch // if so, remove the static value as it may not be correct in this branch
@ -316,11 +307,8 @@ pub fn gen_for<G: CodeGenerator>(
let body_bb = ctx.ctx.append_basic_block(current, "for.body"); let body_bb = ctx.ctx.append_basic_block(current, "for.body");
let cont_bb = ctx.ctx.append_basic_block(current, "for.end"); let cont_bb = ctx.ctx.append_basic_block(current, "for.end");
// if there is no orelse, we just go to cont_bb // if there is no orelse, we just go to cont_bb
let orelse_bb = if orelse.is_empty() { let orelse_bb =
cont_bb if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
} else {
ctx.ctx.append_basic_block(current, "for.orelse")
};
// Whether the iterable is a range() expression // Whether the iterable is a range() expression
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range); let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
@ -334,20 +322,17 @@ pub fn gen_for<G: CodeGenerator>(
let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb)); let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? { let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
v.to_basic_value_enum( v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())?
ctx,
generator,
iter.custom.unwrap(),
)?
} else { } else {
return Ok(()) return Ok(());
}; };
if is_iterable_range_expr { if is_iterable_range_expr {
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range")); let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
// Internal variable for loop; Cannot be assigned // Internal variable for loop; Cannot be assigned
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?; let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed // Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))? else { let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))?
else {
unreachable!() unreachable!()
}; };
let (start, stop, step) = destructure_range(ctx, iter_val); let (start, stop, step) = destructure_range(ctx, iter_val);
@ -355,16 +340,15 @@ pub fn gen_for<G: CodeGenerator>(
ctx.builder.build_store(i, start).unwrap(); ctx.builder.build_store(i, start).unwrap();
// Check "If step is zero, ValueError is raised." // Check "If step is zero, ValueError is raised."
let rangenez = ctx.builder let rangenez =
.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "") ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap();
.unwrap();
ctx.make_assert( ctx.make_assert(
generator, generator,
rangenez, rangenez,
"ValueError", "ValueError",
"range() arg 3 must not be zero", "range() arg 3 must not be zero",
[None, None, None], [None, None, None],
ctx.current_loc ctx.current_loc,
); );
ctx.builder.build_unconditional_branch(cond_bb).unwrap(); ctx.builder.build_unconditional_branch(cond_bb).unwrap();
@ -385,7 +369,8 @@ pub fn gen_for<G: CodeGenerator>(
} }
ctx.builder.position_at_end(incr_bb); ctx.builder.position_at_end(incr_bb);
let next_i = ctx.builder let next_i = ctx
.builder
.build_int_add( .build_int_add(
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(), ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
step, step,
@ -410,13 +395,14 @@ pub fn gen_for<G: CodeGenerator>(
.build_gep_and_load( .build_gep_and_load(
iter_val.into_pointer_value(), iter_val.into_pointer_value(),
&[zero, int32.const_int(1, false)], &[zero, int32.const_int(1, false)],
Some("len") Some("len"),
) )
.into_int_value(); .into_int_value();
ctx.builder.build_unconditional_branch(cond_bb).unwrap(); ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(cond_bb); ctx.builder.position_at_end(cond_bb);
let index = ctx.builder let index = ctx
.builder
.build_load(index_addr, "for.index") .build_load(index_addr, "for.index")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
@ -424,7 +410,8 @@ pub fn gen_for<G: CodeGenerator>(
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap(); ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap();
ctx.builder.position_at_end(incr_bb); ctx.builder.position_at_end(incr_bb);
let index = ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap(); let index =
ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap(); let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap();
ctx.builder.build_store(index_addr, inc).unwrap(); ctx.builder.build_store(index_addr, inc).unwrap();
ctx.builder.build_unconditional_branch(cond_bb).unwrap(); ctx.builder.build_unconditional_branch(cond_bb).unwrap();
@ -433,7 +420,8 @@ pub fn gen_for<G: CodeGenerator>(
let arr_ptr = ctx let arr_ptr = ctx
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr")) .build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr"))
.into_pointer_value(); .into_pointer_value();
let index = ctx.builder let index = ctx
.builder
.build_load(index_addr, "for.index") .build_load(index_addr, "for.index")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
@ -496,13 +484,13 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
body: BodyFn, body: BodyFn,
update: UpdateFn, update: UpdateFn,
) -> Result<(), String> ) -> Result<(), String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
I: Clone, I: Clone,
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>, InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{ {
let current_bb = ctx.builder.get_insert_block().unwrap(); let current_bb = ctx.builder.get_insert_block().unwrap();
let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init"); let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init");
@ -528,9 +516,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
let cond = cond(generator, ctx, loop_var.clone())?; let cond = cond(generator, ctx, loop_var.clone())?;
assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width()); assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
if !ctx.is_terminated() { if !ctx.is_terminated() {
ctx.builder ctx.builder.build_conditional_branch(cond, body_bb, cont_bb).unwrap();
.build_conditional_branch(cond, body_bb, cont_bb)
.unwrap();
} }
ctx.builder.position_at_end(body_bb); ctx.builder.position_at_end(body_bb);
@ -551,7 +537,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
Ok(()) Ok(())
} }
/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the /// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the
/// following C code: /// following C code:
/// ///
/// ```c /// ```c
@ -560,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
/// } /// }
/// ``` /// ```
/// ///
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used /// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// as the type of the loop variable. /// as the type of the loop variable.
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum /// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
/// value should be treated as inclusive (as opposed to exclusive). /// value should be treated as inclusive (as opposed to exclusive).
@ -574,9 +560,9 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
body: BodyFn, body: BodyFn,
incr_val: IntValue<'ctx>, incr_val: IntValue<'ctx>,
) -> Result<(), String> ) -> Result<(), String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
{ {
let init_val_t = init_val.get_type(); let init_val_t = init_val.get_type();
@ -590,38 +576,23 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
Ok(i_addr) Ok(i_addr)
}, },
|_, ctx, i_addr| { |_, ctx, i_addr| {
let cmp_op = if max_val.1 { let cmp_op = if max_val.1 { IntPredicate::ULE } else { IntPredicate::ULT };
IntPredicate::ULE
} else {
IntPredicate::ULT
};
let i = ctx.builder let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
.build_load(i_addr, "") let max_val =
.map(BasicValueEnum::into_int_value) ctx.builder.build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "").unwrap();
.unwrap();
let max_val = ctx.builder
.build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "")
.unwrap();
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap()) Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
}, },
|generator, ctx, i_addr| { |generator, ctx, i_addr| {
let i = ctx.builder let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
body(generator, ctx, i) body(generator, ctx, i)
}, },
|_, ctx, i_addr| { |_, ctx, i_addr| {
let i = ctx.builder let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
.build_load(i_addr, "") let incr_val =
.map(BasicValueEnum::into_int_value) ctx.builder.build_int_z_extend_or_bit_cast(incr_val, init_val_t, "").unwrap();
.unwrap();
let incr_val = ctx.builder
.build_int_z_extend_or_bit_cast(incr_val, init_val_t, "")
.unwrap();
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap(); let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap(); ctx.builder.build_store(i_addr, i).unwrap();
@ -632,21 +603,21 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
/// Generates a `for` construct over a `range`-like iterable using lambdas, similar to the following /// Generates a `for` construct over a `range`-like iterable using lambdas, similar to the following
/// C code: /// C code:
/// ///
/// ```c /// ```c
/// bool incr = start_fn() <= end_fn(); /// bool incr = start_fn() <= end_fn();
/// for (int i = start_fn(); i /* < or > */ end_fn(); i += step_fn()) { /// for (int i = start_fn(); i /* < or > */ end_fn(); i += step_fn()) {
/// body_fn(i); /// body_fn(i);
/// } /// }
/// ``` /// ```
/// ///
/// - `is_unsigned`: Whether to treat the values of the `range` as unsigned. /// - `is_unsigned`: Whether to treat the values of the `range` as unsigned.
/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like /// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like
/// iterable. /// iterable.
/// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like /// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like
/// iterable. This value will be extended to the size of `start`. /// iterable. This value will be extended to the size of `start`.
/// - `stop_inclusive`: Whether the stop value should be treated as inclusive. /// - `stop_inclusive`: Whether the stop value should be treated as inclusive.
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like /// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
/// iterable. This value will be extended to the size of `start`. /// iterable. This value will be extended to the size of `start`.
/// - `body_fn`: A lambda of IR statements within the loop body. /// - `body_fn`: A lambda of IR statements within the loop body.
pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>( pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
@ -658,16 +629,14 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
step_fn: StepFn, step_fn: StepFn,
body_fn: BodyFn, body_fn: BodyFn,
) -> Result<(), String> ) -> Result<(), String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>, BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
{ {
let init_val_t = start_fn(generator, ctx) let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap();
.map(IntValue::get_type)
.unwrap();
gen_for_callback( gen_for_callback(
generator, generator,
@ -688,12 +657,15 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap() ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap()
}; };
let incr = ctx.builder.build_int_compare( let incr = ctx
if is_unsigned { IntPredicate::ULE } else { IntPredicate::SLE }, .builder
start, .build_int_compare(
stop, if is_unsigned { IntPredicate::ULE } else { IntPredicate::SLE },
"", start,
).unwrap(); stop,
"",
)
.unwrap();
Ok((i_addr, incr)) Ok((i_addr, incr))
}, },
@ -705,10 +677,7 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
(false, false) => (IntPredicate::SLT, IntPredicate::SGT), (false, false) => (IntPredicate::SLT, IntPredicate::SGT),
}; };
let i = ctx.builder let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let stop = stop_fn(generator, ctx)?; let stop = stop_fn(generator, ctx)?;
let stop = if stop.get_type().get_bit_width() == i.get_type().get_bit_width() { let stop = if stop.get_type().get_bit_width() == i.get_type().get_bit_width() {
stop stop
@ -718,14 +687,11 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap() ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap()
}; };
let i_lt_end = ctx.builder let i_lt_end = ctx.builder.build_int_compare(lt_cmp_op, i, stop, "").unwrap();
.build_int_compare(lt_cmp_op, i, stop, "") let i_gt_end = ctx.builder.build_int_compare(gt_cmp_op, i, stop, "").unwrap();
.unwrap();
let i_gt_end = ctx.builder
.build_int_compare(gt_cmp_op, i, stop, "")
.unwrap();
let cond = ctx.builder let cond = ctx
.builder
.build_select(incr, i_lt_end, i_gt_end, "") .build_select(incr, i_lt_end, i_gt_end, "")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
@ -733,18 +699,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
Ok(cond) Ok(cond)
}, },
|generator, ctx, (i_addr, _)| { |generator, ctx, (i_addr, _)| {
let i = ctx.builder let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
body_fn(generator, ctx, i) body_fn(generator, ctx, i)
}, },
|generator, ctx, (i_addr, _)| { |generator, ctx, (i_addr, _)| {
let i = ctx.builder let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let incr_val = step_fn(generator, ctx)?; let incr_val = step_fn(generator, ctx)?;
let incr_val = if incr_val.get_type().get_bit_width() == i.get_type().get_bit_width() { let incr_val = if incr_val.get_type().get_bit_width() == i.get_type().get_bit_width() {
@ -769,9 +729,7 @@ pub fn gen_while<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::While { test, body, orelse, .. } = &stmt.node else { let StmtKind::While { test, body, orelse, .. } = &stmt.node else { unreachable!() };
unreachable!()
};
// var_assignment static values may be changed in another branch // var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch // if so, remove the static value as it may not be correct in this branch
@ -782,8 +740,11 @@ pub fn gen_while<G: CodeGenerator>(
let body_bb = ctx.ctx.append_basic_block(current, "while.body"); let body_bb = ctx.ctx.append_basic_block(current, "while.body");
let cont_bb = ctx.ctx.append_basic_block(current, "while.cont"); let cont_bb = ctx.ctx.append_basic_block(current, "while.cont");
// if there is no orelse, we just go to cont_bb // if there is no orelse, we just go to cont_bb
let orelse_bb = let orelse_bb = if orelse.is_empty() {
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "while.orelse") }; cont_bb
} else {
ctx.ctx.append_basic_block(current, "while.orelse")
};
// store loop bb information and restore it later // store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((test_bb, cont_bb)); let loop_bb = ctx.loop_target.replace((test_bb, cont_bb));
ctx.builder.build_unconditional_branch(test_bb).unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap();
@ -796,11 +757,9 @@ pub fn gen_while<G: CodeGenerator>(
ctx.builder.build_unreachable().unwrap(); ctx.builder.build_unreachable().unwrap();
} }
return Ok(()) return Ok(());
};
let BasicValueEnum::IntValue(test) = test else {
unreachable!()
}; };
let BasicValueEnum::IntValue(test) = test else { unreachable!() };
ctx.builder ctx.builder
.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb) .build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb)
@ -853,12 +812,12 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
then_fn: ThenFn, then_fn: ThenFn,
else_fn: ElseFn, else_fn: ElseFn,
) -> Result<Option<BasicValueEnum<'ctx>>, String> ) -> Result<Option<BasicValueEnum<'ctx>>, String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>, ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>, ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
R: BasicValue<'ctx>, R: BasicValue<'ctx>,
{ {
let current_bb = ctx.builder.get_insert_block().unwrap(); let current_bb = ctx.builder.get_insert_block().unwrap();
@ -893,8 +852,8 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
let phi = ctx.builder.build_phi(tv_ty, "").unwrap(); let phi = ctx.builder.build_phi(tv_ty, "").unwrap();
phi.add_incoming(&[(&tv, then_end_bb), (&ev, else_end_bb)]); phi.add_incoming(&[(&tv, then_end_bb), (&ev, else_end_bb)]);
Some(phi.as_basic_value()) Some(phi.as_basic_value())
}, }
(Some(tv), None) => Some(tv.as_basic_value_enum()), (Some(tv), None) => Some(tv.as_basic_value_enum()),
(None, Some(ev)) => Some(ev.as_basic_value_enum()), (None, Some(ev)) => Some(ev.as_basic_value_enum()),
(None, None) => None, (None, None) => None,
@ -919,11 +878,11 @@ pub fn gen_if_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn>(
then_fn: ThenFn, then_fn: ThenFn,
else_fn: ElseFn, else_fn: ElseFn,
) -> Result<(), String> ) -> Result<(), String>
where where
G: CodeGenerator + ?Sized, G: CodeGenerator + ?Sized,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>, CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>, ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>, ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
{ {
gen_if_else_expr_callback( gen_if_else_expr_callback(
generator, generator,
@ -936,7 +895,7 @@ pub fn gen_if_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn>(
|generator, ctx| { |generator, ctx| {
else_fn(generator, ctx)?; else_fn(generator, ctx)?;
Ok(None) Ok(None)
} },
)?; )?;
Ok(()) Ok(())
@ -948,9 +907,7 @@ pub fn gen_if<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>, ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>, stmt: &Stmt<Option<Type>>,
) -> Result<(), String> { ) -> Result<(), String> {
let StmtKind::If { test, body, orelse, .. } = &stmt.node else { let StmtKind::If { test, body, orelse, .. } = &stmt.node else { unreachable!() };
unreachable!()
};
// var_assignment static values may be changed in another branch // var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch // if so, remove the static value as it may not be correct in this branch
@ -969,9 +926,9 @@ pub fn gen_if<G: CodeGenerator>(
}; };
ctx.builder.build_unconditional_branch(test_bb).unwrap(); ctx.builder.build_unconditional_branch(test_bb).unwrap();
ctx.builder.position_at_end(test_bb); ctx.builder.position_at_end(test_bb);
let test = generator let test = generator.gen_expr(ctx, test).and_then(|v| {
.gen_expr(ctx, test) v.map(|v| v.to_basic_value_enum(ctx, generator, test.custom.unwrap())).transpose()
.and_then(|v| v.map(|v| v.to_basic_value_enum(ctx, generator, test.custom.unwrap())).transpose())?; })?;
if let Some(BasicValueEnum::IntValue(test)) = test { if let Some(BasicValueEnum::IntValue(test)) = test {
ctx.builder ctx.builder
.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb) .build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb)
@ -1077,16 +1034,16 @@ pub fn exn_constructor<'ctx>(
}; };
let defs = ctx.top_level.definitions.read(); let defs = ctx.top_level.definitions.read();
let def = defs[zelf_id].read(); let def = defs[zelf_id].read();
let TopLevelDef::Class { name: zelf_name, .. } = &*def else { let TopLevelDef::Class { name: zelf_name, .. } = &*def else { unreachable!() };
unreachable!()
};
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name); let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name);
unsafe { unsafe {
let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap(); let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap();
let id = ctx.resolver.get_string_id(&exception_name); let id = ctx.resolver.get_string_id(&exception_name);
ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false)).unwrap(); ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false)).unwrap();
let empty_string = ctx.gen_const(generator, &Constant::Str(String::new()), ctx.primitives.str); let empty_string =
let ptr = ctx.builder ctx.gen_const(generator, &Constant::Str(String::new()), ctx.primitives.str);
let ptr = ctx
.builder
.build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg") .build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg")
.unwrap(); .unwrap();
let msg = if args.is_empty() { let msg = if args.is_empty() {
@ -1101,21 +1058,24 @@ pub fn exn_constructor<'ctx>(
} else { } else {
args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.int64)? args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.int64)?
}; };
let ptr = ctx.builder let ptr = ctx
.builder
.build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.param") .build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.param")
.unwrap(); .unwrap();
ctx.builder.build_store(ptr, value).unwrap(); ctx.builder.build_store(ptr, value).unwrap();
} }
// set file, func to empty string // set file, func to empty string
for i in &[1, 4] { for i in &[1, 4] {
let ptr = ctx.builder let ptr = ctx
.builder
.build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.str") .build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.str")
.unwrap(); .unwrap();
ctx.builder.build_store(ptr, empty_string.unwrap()).unwrap(); ctx.builder.build_store(ptr, empty_string.unwrap()).unwrap();
} }
// set ints to zero // set ints to zero
for i in &[2, 3] { for i in &[2, 3] {
let ptr = ctx.builder let ptr = ctx
.builder
.build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.ints") .build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.ints")
.unwrap(); .unwrap();
ctx.builder.build_store(ptr, zero).unwrap(); ctx.builder.build_store(ptr, zero).unwrap();
@ -1139,23 +1099,27 @@ pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>(
let int32 = ctx.ctx.i32_type(); let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero(); let zero = int32.const_zero();
let exception = exception.into_pointer_value(); let exception = exception.into_pointer_value();
let file_ptr = ctx.builder let file_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr") .build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr")
.unwrap(); .unwrap();
let filename = ctx.gen_string(generator, loc.file.0); let filename = ctx.gen_string(generator, loc.file.0);
ctx.builder.build_store(file_ptr, filename).unwrap(); ctx.builder.build_store(file_ptr, filename).unwrap();
let row_ptr = ctx.builder let row_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr") .build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr")
.unwrap(); .unwrap();
ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)).unwrap(); ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)).unwrap();
let col_ptr = ctx.builder let col_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr") .build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr")
.unwrap(); .unwrap();
ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)).unwrap(); ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)).unwrap();
let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap(); let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap()); let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap());
let name_ptr = ctx.builder let name_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr") .build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr")
.unwrap(); .unwrap();
ctx.builder.build_store(name_ptr, fun_name).unwrap(); ctx.builder.build_store(name_ptr, fun_name).unwrap();
@ -1204,7 +1168,8 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
let mut final_data = None; let mut final_data = None;
let has_cleanup = !finalbody.is_empty(); let has_cleanup = !finalbody.is_empty();
if has_cleanup { if has_cleanup {
let final_state = generator.gen_var_alloc(ctx, ptr_type.into(), Some("try.final_state.addr"))?; let final_state =
generator.gen_var_alloc(ctx, ptr_type.into(), Some("try.final_state.addr"))?;
final_data = Some((final_state, Vec::new(), Vec::new())); final_data = Some((final_state, Vec::new(), Vec::new()));
if let Some((continue_target, break_target)) = ctx.loop_target { if let Some((continue_target, break_target)) = ctx.loop_target {
let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break"); let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break");
@ -1219,8 +1184,8 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
} else { } else {
let return_target = ctx.ctx.append_basic_block(current_fun, "try.return_target"); let return_target = ctx.ctx.append_basic_block(current_fun, "try.return_target");
ctx.builder.position_at_end(return_target); ctx.builder.position_at_end(return_target);
let return_value = ctx.return_buffer let return_value =
.map(|v| ctx.builder.build_load(v, "$ret").unwrap()); ctx.return_buffer.map(|v| ctx.builder.build_load(v, "$ret").unwrap());
ctx.builder.build_return(return_value.as_ref().map(|v| v as &dyn BasicValue)).unwrap(); ctx.builder.build_return(return_value.as_ref().map(|v| v as &dyn BasicValue)).unwrap();
ctx.builder.position_at_end(current_block); ctx.builder.position_at_end(current_block);
final_proxy(ctx, return_target, return_proxy, final_data.as_mut().unwrap()); final_proxy(ctx, return_target, return_proxy, final_data.as_mut().unwrap());
@ -1250,11 +1215,12 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
&mut ctx.unifier, &mut ctx.unifier,
type_.custom.unwrap(), type_.custom.unwrap(),
); );
let obj_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) { let obj_id =
*obj_id if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) {
} else { *obj_id
unreachable!() } else {
}; unreachable!()
};
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name); let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name);
let exn_id = ctx.resolver.get_string_id(&exception_name); let exn_id = ctx.resolver.get_string_id(&exception_name);
let exn_id_global = let exn_id_global =
@ -1303,16 +1269,15 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
// run end_catch before continue/break/return // run end_catch before continue/break/return
let mut final_proxy_lambda = let mut final_proxy_lambda =
|ctx: &mut CodeGenContext<'ctx, 'a>, |ctx: &mut CodeGenContext<'ctx, 'a>, target: BasicBlock<'ctx>, block: BasicBlock<'ctx>| {
target: BasicBlock<'ctx>, final_proxy(ctx, target, block, final_data.as_mut().unwrap());
block: BasicBlock<'ctx>| final_proxy(ctx, target, block, final_data.as_mut().unwrap()); };
let mut redirect_lambda = |ctx: &mut CodeGenContext<'ctx, 'a>, let mut redirect_lambda =
target: BasicBlock<'ctx>, |ctx: &mut CodeGenContext<'ctx, 'a>, target: BasicBlock<'ctx>, block: BasicBlock<'ctx>| {
block: BasicBlock<'ctx>| { ctx.builder.position_at_end(block);
ctx.builder.position_at_end(block); ctx.builder.build_unconditional_branch(target).unwrap();
ctx.builder.build_unconditional_branch(target).unwrap(); ctx.builder.position_at_end(body);
ctx.builder.position_at_end(body); };
};
let redirect = if has_cleanup { let redirect = if has_cleanup {
&mut final_proxy_lambda &mut final_proxy_lambda
as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>) as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>)
@ -1357,12 +1322,9 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
ctx.builder.position_at_end(dispatcher); ctx.builder.position_at_end(dispatcher);
unsafe { unsafe {
let zero = ctx.ctx.i32_type().const_zero(); let zero = ctx.ctx.i32_type().const_zero();
let exnid_ptr = ctx.builder let exnid_ptr = ctx
.build_gep( .builder
exn.as_basic_value().into_pointer_value(), .build_gep(exn.as_basic_value().into_pointer_value(), &[zero, zero], "exnidptr")
&[zero, zero],
"exnidptr",
)
.unwrap(); .unwrap();
Some(ctx.builder.build_load(exnid_ptr, "exnid").unwrap()) Some(ctx.builder.build_load(exnid_ptr, "exnid").unwrap())
} }
@ -1388,15 +1350,15 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
post_handlers.push(current); post_handlers.push(current);
ctx.builder.position_at_end(dispatcher_end); ctx.builder.position_at_end(dispatcher_end);
if let Some(exn_type) = exn_type { if let Some(exn_type) = exn_type {
let dispatcher_cont = let dispatcher_cont = ctx.ctx.append_basic_block(current_fun, "try.dispatcher_cont");
ctx.ctx.append_basic_block(current_fun, "try.dispatcher_cont");
let actual_id = exnid.unwrap().into_int_value(); let actual_id = exnid.unwrap().into_int_value();
let expected_id = ctx let expected_id = ctx
.builder .builder
.build_load(exn_type.into_pointer_value(), "expected_id") .build_load(exn_type.into_pointer_value(), "expected_id")
.map(BasicValueEnum::into_int_value) .map(BasicValueEnum::into_int_value)
.unwrap(); .unwrap();
let result = ctx.builder let result = ctx
.builder
.build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck") .build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck")
.unwrap(); .unwrap();
ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont).unwrap(); ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont).unwrap();
@ -1522,11 +1484,9 @@ pub fn gen_return<G: CodeGenerator>(
let func = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap(); let func = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
let value = if let Some(v_expr) = value.as_ref() { let value = if let Some(v_expr) = value.as_ref() {
if let Some(v) = generator.gen_expr(ctx, v_expr).transpose() { if let Some(v) = generator.gen_expr(ctx, v_expr).transpose() {
Some( Some(v.and_then(|v| v.to_basic_value_enum(ctx, generator, v_expr.custom.unwrap()))?)
v.and_then(|v| v.to_basic_value_enum(ctx, generator, v_expr.custom.unwrap()))?
)
} else { } else {
return Ok(()) return Ok(());
} }
} else { } else {
None None
@ -1554,7 +1514,8 @@ pub fn gen_return<G: CodeGenerator>(
generator.bool_to_i1(ctx, ret_val) generator.bool_to_i1(ctx, ret_val)
} else { } else {
ret_val ret_val
}.into() }
.into()
} else { } else {
ret_val ret_val
} }
@ -1592,16 +1553,12 @@ pub fn gen_stmt<G: CodeGenerator>(
} }
StmtKind::AnnAssign { target, value, .. } => { StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value { if let Some(value) = value {
let Some(value) = generator.gen_expr(ctx, value)? else { let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
return Ok(())
};
generator.gen_assign(ctx, target, value)?; generator.gen_assign(ctx, target, value)?;
} }
} }
StmtKind::Assign { targets, value, .. } => { StmtKind::Assign { targets, value, .. } => {
let Some(value) = generator.gen_expr(ctx, value)? else { let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
return Ok(())
};
for target in targets { for target in targets {
generator.gen_assign(ctx, target, value.clone())?; generator.gen_assign(ctx, target, value.clone())?;
} }
@ -1617,7 +1574,7 @@ pub fn gen_stmt<G: CodeGenerator>(
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
StmtKind::AugAssign { target, op, value, .. } => { StmtKind::AugAssign { target, op, value, .. } => {
let value = gen_binop_expr(generator, ctx, target, op, value, stmt.location, true)?; let value = gen_binop_expr(generator, ctx, target, *op, value, stmt.location, true)?;
generator.gen_assign(ctx, target, value.unwrap())?; generator.gen_assign(ctx, target, value.unwrap())?;
} }
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
@ -1626,7 +1583,7 @@ pub fn gen_stmt<G: CodeGenerator>(
let exc = if let Some(v) = generator.gen_expr(ctx, exc)? { let exc = if let Some(v) = generator.gen_expr(ctx, exc)? {
v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())? v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())?
} else { } else {
return Ok(()) return Ok(());
}; };
gen_raise(generator, ctx, Some(&exc), stmt.location); gen_raise(generator, ctx, Some(&exc), stmt.location);
} else { } else {
@ -1637,14 +1594,16 @@ pub fn gen_stmt<G: CodeGenerator>(
let test = if let Some(v) = generator.gen_expr(ctx, test)? { let test = if let Some(v) = generator.gen_expr(ctx, test)? {
v.to_basic_value_enum(ctx, generator, test.custom.unwrap())? v.to_basic_value_enum(ctx, generator, test.custom.unwrap())?
} else { } else {
return Ok(()) return Ok(());
}; };
let err_msg = match msg { let err_msg = match msg {
Some(msg) => if let Some(v) = generator.gen_expr(ctx, msg)? { Some(msg) => {
v.to_basic_value_enum(ctx, generator, msg.custom.unwrap())? if let Some(v) = generator.gen_expr(ctx, msg)? {
} else { v.to_basic_value_enum(ctx, generator, msg.custom.unwrap())?
return Ok(()) } else {
}, return Ok(());
}
}
None => ctx.gen_string(generator, ""), None => ctx.gen_string(generator, ""),
}; };
ctx.make_assert_impl( ctx.make_assert_impl(
@ -1656,7 +1615,7 @@ pub fn gen_stmt<G: CodeGenerator>(
stmt.location, stmt.location,
); );
} }
_ => unimplemented!() _ => unimplemented!(),
}; };
Ok(()) Ok(())
} }

View File

@ -1,13 +1,14 @@
use crate::{ use crate::{
codegen::{ codegen::{
classes::{ListType, NDArrayType, ProxyType, RangeType}, classes::{ListType, NDArrayType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenerator, CodeGenLLVMOptions, concrete_type::ConcreteTypeStore,
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry, CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
}, },
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{ toplevel::{
composer::{ComposerConfig, TopLevelComposer}, DefinitionId, FunInstance, TopLevelContext, composer::{ComposerConfig, TopLevelComposer},
TopLevelDef, DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
}, },
typecheck::{ typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore}, type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
@ -17,7 +18,7 @@ use crate::{
use indoc::indoc; use indoc::indoc;
use inkwell::{ use inkwell::{
targets::{InitializationConfig, Target}, targets::{InitializationConfig, Target},
OptimizationLevel OptimizationLevel,
}; };
use nac3parser::{ use nac3parser::{
ast::{fold::Fold, StrRef}, ast::{fold::Fold, StrRef},
@ -70,9 +71,7 @@ impl SymbolResolver for Resolver {
.read() .read()
.get(&id) .get(&id)
.cloned() .cloned()
.ok_or_else(|| HashSet::from([ .ok_or_else(|| HashSet::from([format!("cannot find symbol `{}`", id)]))
format!("cannot find symbol `{}`", id),
]))
} }
fn get_string_id(&self, _: &str) -> i32 { fn get_string_id(&self, _: &str) -> i32 {
@ -227,12 +226,7 @@ fn test_primitives() {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(), target: CodeGenTargetMachineOptions::from_host_triple(),
}; };
let (registry, handles) = WorkerRegistry::create_workers( let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }
@ -417,12 +411,7 @@ fn test_simple_call() {
opt_level: OptimizationLevel::Default, opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(), target: CodeGenTargetMachineOptions::from_host_triple(),
}; };
let (registry, handles) = WorkerRegistry::create_workers( let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
threads,
top_level,
&llvm_options,
&f
);
registry.add_task(task); registry.add_task(task);
registry.wait_tasks_complete(handles); registry.wait_tasks_complete(handles);
} }

View File

@ -1,5 +1,23 @@
#![warn(clippy::all)] #![deny(
#![allow(dead_code)] future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
dead_code,
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::enum_glob_use,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::similar_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
pub mod codegen; pub mod codegen;
pub mod symbol_resolver; pub mod symbol_resolver;

View File

@ -1,18 +1,18 @@
use std::fmt::Debug; use std::fmt::Debug;
use std::rc::Rc;
use std::sync::Arc; use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display}; use std::{collections::HashMap, collections::HashSet, fmt::Display};
use std::rc::Rc;
use crate::{ use crate::{
codegen::{CodeGenContext, CodeGenerator}, codegen::{CodeGenContext, CodeGenerator},
toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation}, toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap}, typedef::{Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue}; use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, Itertools, izip}; use itertools::{chain, izip, Itertools};
use nac3parser::ast::{Constant, Expr, Location, StrRef}; use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock; use parking_lot::RwLock;
@ -39,7 +39,7 @@ impl SymbolValue {
constant: &Constant, constant: &Constant,
expected_ty: Type, expected_ty: Type,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
unifier: &mut Unifier unifier: &mut Unifier,
) -> Result<Self, String> { ) -> Result<Self, String> {
match constant { match constant {
Constant::None => { Constant::None => {
@ -62,24 +62,16 @@ impl SymbolValue {
} else { } else {
Err(format!("Expected {expected_ty:?}, but got str")) Err(format!("Expected {expected_ty:?}, but got str"))
} }
}, }
Constant::Int(i) => { Constant::Int(i) => {
if unifier.unioned(expected_ty, primitives.int32) { if unifier.unioned(expected_ty, primitives.int32) {
i32::try_from(*i) i32::try_from(*i).map(SymbolValue::I32).map_err(|e| e.to_string())
.map(SymbolValue::I32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.int64) { } else if unifier.unioned(expected_ty, primitives.int64) {
i64::try_from(*i) i64::try_from(*i).map(SymbolValue::I64).map_err(|e| e.to_string())
.map(SymbolValue::I64)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint32) { } else if unifier.unioned(expected_ty, primitives.uint32) {
u32::try_from(*i) u32::try_from(*i).map(SymbolValue::U32).map_err(|e| e.to_string())
.map(SymbolValue::U32)
.map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint64) { } else if unifier.unioned(expected_ty, primitives.uint64) {
u64::try_from(*i) u64::try_from(*i).map(SymbolValue::U64).map_err(|e| e.to_string())
.map(SymbolValue::U64)
.map_err(|e| e.to_string())
} else { } else {
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty))) Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
} }
@ -87,7 +79,10 @@ impl SymbolValue {
Constant::Tuple(t) => { Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty); let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else { let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
return Err(format!("Expected {:?}, but got Tuple", expected_ty.get_type_name())) return Err(format!(
"Expected {:?}, but got Tuple",
expected_ty.get_type_name()
));
}; };
assert_eq!(ty.len(), t.len()); assert_eq!(ty.len(), t.len());
@ -105,7 +100,7 @@ impl SymbolValue {
} else { } else {
Err(format!("Expected {expected_ty:?}, but got float")) Err(format!("Expected {expected_ty:?}, but got float"))
} }
}, }
_ => Err(format!("Unsupported value type {constant:?}")), _ => Err(format!("Unsupported value type {constant:?}")),
} }
} }
@ -113,9 +108,7 @@ impl SymbolValue {
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value. /// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
/// ///
/// * `constant` - The constant to create the value from. /// * `constant` - The constant to create the value from.
pub fn from_constant_inferred( pub fn from_constant_inferred(constant: &Constant) -> Result<Self, String> {
constant: &Constant,
) -> Result<Self, String> {
match constant { match constant {
Constant::None => Ok(SymbolValue::OptionNone), Constant::None => Ok(SymbolValue::OptionNone),
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)), Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
@ -123,13 +116,19 @@ impl SymbolValue {
Constant::Int(i) => { Constant::Int(i) => {
let i = *i; let i = *i;
if i >= 0 { if i >= 0 {
i32::try_from(i).map(SymbolValue::I32) i32::try_from(i)
.map(SymbolValue::I32)
.or_else(|_| i64::try_from(i).map(SymbolValue::I64)) .or_else(|_| i64::try_from(i).map(SymbolValue::I64))
.map_err(|_| format!("Literal cannot be expressed as any integral type: {i}")) .map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
} else { } else {
u32::try_from(i).map(SymbolValue::U32) u32::try_from(i)
.map(SymbolValue::U32)
.or_else(|_| u64::try_from(i).map(SymbolValue::U64)) .or_else(|_| u64::try_from(i).map(SymbolValue::U64))
.map_err(|_| format!("Literal cannot be expressed as any integral type: {i}")) .map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
} }
} }
Constant::Tuple(t) => { Constant::Tuple(t) => {
@ -155,20 +154,19 @@ impl SymbolValue {
SymbolValue::Double(_) => primitives.float, SymbolValue::Double(_) => primitives.float,
SymbolValue::Bool(_) => primitives.bool, SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => { SymbolValue::Tuple(vs) => {
let vs_tys = vs let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
.iter() unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
.map(|v| v.get_type(primitives, unifier))
.collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple {
ty: vs_tys,
})
} }
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option, SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
} }
} }
/// Returns the [`TypeAnnotation`] representing the data type of this value. /// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation { pub fn get_type_annotation(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> TypeAnnotation {
match self { match self {
SymbolValue::Bool(..) SymbolValue::Bool(..)
| SymbolValue::Double(..) | SymbolValue::Double(..)
@ -199,7 +197,11 @@ impl SymbolValue {
} }
/// Returns the [`TypeEnum`] representing the data type of this value. /// Returns the [`TypeEnum`] representing the data type of this value.
pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc<TypeEnum> { pub fn get_type_enum(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier); let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty) unifier.get_ty(ty)
} }
@ -239,7 +241,7 @@ impl TryFrom<SymbolValue> for u64 {
match value { match value {
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()), SymbolValue::I32(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()), SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::U32(v) => Ok(v as u64), SymbolValue::U32(v) => Ok(u64::from(v)),
SymbolValue::U64(v) => Ok(v), SymbolValue::U64(v) => Ok(v),
_ => Err(()), _ => Err(()),
} }
@ -253,10 +255,10 @@ impl TryFrom<SymbolValue> for i128 {
/// numeric. /// numeric.
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> { fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value { match value {
SymbolValue::I32(v) => Ok(v as i128), SymbolValue::I32(v) => Ok(i128::from(v)),
SymbolValue::I64(v) => Ok(v as i128), SymbolValue::I64(v) => Ok(i128::from(v)),
SymbolValue::U32(v) => Ok(v as i128), SymbolValue::U32(v) => Ok(i128::from(v)),
SymbolValue::U64(v) => Ok(v as i128), SymbolValue::U64(v) => Ok(i128::from(v)),
_ => Err(()), _ => Err(()),
} }
} }
@ -332,7 +334,6 @@ impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
} }
impl<'ctx> ValueEnum<'ctx> { impl<'ctx> ValueEnum<'ctx> {
/// Converts this [`ValueEnum`] to a [`BasicValueEnum`]. /// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
pub fn to_basic_value_enum<'a>( pub fn to_basic_value_enum<'a>(
self, self,
@ -374,7 +375,7 @@ pub trait SymbolResolver {
&self, &self,
_unifier: &mut Unifier, _unifier: &mut Unifier,
_top_level_defs: &[Arc<RwLock<TopLevelDef>>], _top_level_defs: &[Arc<RwLock<TopLevelDef>>],
_primitives: &PrimitiveStore _primitives: &PrimitiveStore,
) -> Result<(), String> { ) -> Result<(), String> {
Ok(()) Ok(())
} }
@ -443,40 +444,29 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() { if !type_vars.is_empty() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "Unexpected number of type parameters: expected {} but got 0",
"Unexpected number of type parameters: expected {} but got 0", type_vars.len()
type_vars.len() )]));
),
]))
} }
let fields = chain( let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))), fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))), methods.iter().map(|(k, v, _)| (*k, (*v, false))),
) )
.collect(); .collect();
Ok(unifier.add_ty(TypeEnum::TObj { Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() }))
obj_id,
fields,
params: VarMap::default(),
}))
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!("Cannot use function name as type at {loc}")]))
format!("Cannot use function name as type at {loc}"),
]))
} }
} else { } else {
let ty = resolver let ty =
.get_symbol_type(unifier, top_level_defs, primitives, *id) resolver.get_symbol_type(unifier, top_level_defs, primitives, *id).map_err(
.map_err(|e| HashSet::from([ |e| HashSet::from([format!("Unknown type annotation at {loc}: {e}")]),
format!("Unknown type annotation at {loc}: {e}"), )?;
]))?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) { if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty) Ok(ty)
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!("Unknown type annotation {id} at {loc}")]))
format!("Unknown type annotation {id} at {loc}"),
]))
} }
} }
} }
@ -499,9 +489,7 @@ pub fn parse_type_annotation<T>(
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty })) Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else { } else {
Err(HashSet::from([ Err(HashSet::from(["Expected multiple elements for tuple".into()]))
"Expected multiple elements for tuple".into()
]))
} }
} else if *id == literal_id { } else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| { let mut parse_literal = |elt: &Expr<T>| {
@ -509,19 +497,21 @@ pub fn parse_type_annotation<T>(
let ty_enum = &*unifier.get_ty_immutable(ty); let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum { match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()), TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!("Expected literal in type argument for Literal at {}", elt.location), "Expected literal in type argument for Literal at {}",
])) elt.location
)])),
} }
}; };
let values = if let Tuple { elts, .. } = &slice.node { let values = if let Tuple { elts, .. } = &slice.node {
elts.iter() elts.iter().map(&mut parse_literal).collect::<Result<Vec<_>, _>>()?
.map(&mut parse_literal)
.collect::<Result<Vec<_>, _>>()?
} else { } else {
vec![parse_literal(slice)?] vec![parse_literal(slice)?]
}.into_iter().flatten().collect_vec(); }
.into_iter()
.flatten()
.collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location))) Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else { } else {
@ -539,13 +529,11 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read(); let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def { if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() { if types.len() != type_vars.len() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "Unexpected number of type parameters: expected {} but got {}",
"Unexpected number of type parameters: expected {} but got {}", type_vars.len(),
type_vars.len(), types.len()
types.len() )]));
),
]))
} }
let mut subst = VarMap::new(); let mut subst = VarMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) { for (var, ty) in izip!(type_vars.iter(), types.iter()) {
@ -569,9 +557,7 @@ pub fn parse_type_annotation<T>(
})); }));
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst })) Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
} else { } else {
Err(HashSet::from([ Err(HashSet::from(["Cannot use function name as type".into()]))
"Cannot use function name as type".into(),
]))
} }
} }
}; };
@ -582,17 +568,13 @@ pub fn parse_type_annotation<T>(
if let Name { id, .. } = &value.node { if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier) subscript_name_handle(id, slice, unifier)
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!("unsupported type expression at {}", expr.location)]))
format!("unsupported type expression at {}", expr.location),
]))
} }
} }
Constant { value, .. } => SymbolValue::from_constant_inferred(value) Constant { value, .. } => SymbolValue::from_constant_inferred(value)
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location))) .map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
.map_err(|err| HashSet::from([err])), .map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])),
format!("unsupported type expression at {}", expr.location),
])),
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -82,7 +82,8 @@ impl TopLevelComposer {
let mut builtin_id = HashMap::default(); let mut builtin_id = HashMap::default();
let mut builtin_ty = HashMap::default(); let mut builtin_ty = HashMap::default();
let builtin_name_list = definition_ast_list.iter() let builtin_name_list = definition_ast_list
.iter()
.map(|def_ast| match *def_ast.0.read() { .map(|def_ast| match *def_ast.0.read() {
TopLevelDef::Class { name, .. } => name.to_string(), TopLevelDef::Class { name, .. } => name.to_string(),
TopLevelDef::Function { simple_name, .. } => simple_name.to_string(), TopLevelDef::Function { simple_name, .. } => simple_name.to_string(),
@ -93,19 +94,24 @@ impl TopLevelComposer {
let name = (**name).into(); let name = (**name).into();
let def = definition_ast_list[id].0.read(); let def = definition_ast_list[id].0.read();
if let TopLevelDef::Function { name: func_name, simple_name, signature, .. } = &*def { if let TopLevelDef::Function { name: func_name, simple_name, signature, .. } = &*def {
assert_eq!(name, *simple_name, "Simple name of builtin function should match builtin name list"); assert_eq!(
name, *simple_name,
"Simple name of builtin function should match builtin name list"
);
// Do not add member functions into the list of builtin IDs; // Do not add member functions into the list of builtin IDs;
// Here we assume that all builtin top-level functions have the same name and simple // Here we assume that all builtin top-level functions have the same name and simple
// name, and all member functions have something prefixed to its name // name, and all member functions have something prefixed to its name
if *func_name != simple_name.to_string() { if *func_name != simple_name.to_string() {
continue continue;
} }
builtin_ty.insert(name, *signature); builtin_ty.insert(name, *signature);
builtin_id.insert(name, DefinitionId(id)); builtin_id.insert(name, DefinitionId(id));
} else if let TopLevelDef::Class { name, constructor, object_id, .. } = &*def } else if let TopLevelDef::Class { name, constructor, object_id, .. } = &*def {
{ assert_eq!(
assert_eq!(id, object_id.0, "Object id of class '{name}' should match its index in builtin name list"); id, object_id.0,
"Object id of class '{name}' should match its index in builtin name list"
);
if let Some(constructor) = constructor { if let Some(constructor) = constructor {
builtin_ty.insert(*name, *constructor); builtin_ty.insert(*name, *constructor);
} }
@ -384,9 +390,9 @@ impl TopLevelComposer {
let mut class_def = class_def.write(); let mut class_def = class_def.write();
let (class_bases_ast, class_def_type_vars, class_resolver) = { let (class_bases_ast, class_def_type_vars, class_resolver) = {
if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def { if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def {
let Some(ast::Located { let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) =
node: ast::StmtKind::ClassDef { bases, .. }, .. class_ast
}) = class_ast else { else {
unreachable!() unreachable!()
}; };
@ -415,12 +421,10 @@ impl TopLevelComposer {
} => } =>
{ {
if is_generic { if is_generic {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "only single Generic[...] is allowed (at {})",
"only single Generic[...] is allowed (at {})", b.location
b.location )]));
),
]))
} }
is_generic = true; is_generic = true;
@ -459,12 +463,10 @@ impl TopLevelComposer {
}) })
}; };
if !all_unique_type_var { if !all_unique_type_var {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "duplicate type variable occurs (at {})",
"duplicate type variable occurs (at {})", slice.location
slice.location )]));
),
]))
} }
// add to TopLevelDef // add to TopLevelDef
@ -487,7 +489,7 @@ impl TopLevelComposer {
} }
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors) return Err(errors);
} }
Ok(()) Ok(())
} }
@ -514,9 +516,9 @@ impl TopLevelComposer {
} = &mut *class_def } = &mut *class_def
{ {
let Some(ast::Located { let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, node: ast::StmtKind::ClassDef { bases, .. }, ..
.. }) = class_ast
}) = class_ast else { else {
unreachable!() unreachable!()
}; };
@ -543,13 +545,11 @@ impl TopLevelComposer {
} }
if has_base { if has_base {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "a class definition can only have at most one base class \
"a class definition can only have at most one base class \
declaration and one generic declaration (at {})", declaration and one generic declaration (at {})",
b.location b.location
), )]));
]))
} }
has_base = true; has_base = true;
@ -561,18 +561,18 @@ impl TopLevelComposer {
unifier, unifier,
&primitive_types, &primitive_types,
b, b,
vec![(*class_def_id, class_type_vars.clone())].into_iter().collect(), vec![(*class_def_id, class_type_vars.clone())]
.into_iter()
.collect::<HashMap<_, _>>(),
)?; )?;
if let TypeAnnotation::CustomClass { .. } = &base_ty { if let TypeAnnotation::CustomClass { .. } = &base_ty {
class_ancestors.push(base_ty); class_ancestors.push(base_ty);
} else { } else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "class base declaration can only be custom class (at {})",
"class base declaration can only be custom class (at {})", b.location,
b.location, )]));
),
]))
} }
} }
Ok(()) Ok(())
@ -589,31 +589,35 @@ impl TopLevelComposer {
} }
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors) return Err(errors);
} }
// second, get all ancestors // second, get all ancestors
let mut ancestors_store: HashMap<DefinitionId, Vec<TypeAnnotation>> = HashMap::default(); let mut ancestors_store: HashMap<DefinitionId, Vec<TypeAnnotation>> = HashMap::default();
let mut get_all_ancestors = |class_def: &Arc<RwLock<TopLevelDef>>| -> Result<(), HashSet<String>> { let mut get_all_ancestors =
let class_def = class_def.read(); |class_def: &Arc<RwLock<TopLevelDef>>| -> Result<(), HashSet<String>> {
let (class_ancestors, class_id) = { let class_def = class_def.read();
if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def { let (class_ancestors, class_id) = {
(ancestors, *object_id) if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def {
} else { (ancestors, *object_id)
return Ok(()); } else {
} return Ok(());
}
};
ancestors_store.insert(
class_id,
// if class has direct parents, get all ancestors of its parents. Else just empty
if class_ancestors.is_empty() {
vec![]
} else {
Self::get_all_ancestors_helper(
&class_ancestors[0],
temp_def_list.as_slice(),
)?
},
);
Ok(())
}; };
ancestors_store.insert(
class_id,
// if class has direct parents, get all ancestors of its parents. Else just empty
if class_ancestors.is_empty() {
vec![]
} else {
Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice())?
},
);
Ok(())
};
for (class_def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) { for (class_def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) {
if ast.is_none() { if ast.is_none() {
continue; continue;
@ -623,7 +627,7 @@ impl TopLevelComposer {
} }
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors) return Err(errors);
} }
// insert the ancestors to the def list // insert the ancestors to the def list
@ -633,8 +637,7 @@ impl TopLevelComposer {
} }
let mut class_def = class_def.write(); let mut class_def = class_def.write();
let (class_ancestors, class_id, class_type_vars) = { let (class_ancestors, class_id, class_type_vars) = {
if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = &mut *class_def
&mut *class_def
{ {
(ancestors, *object_id, type_vars) (ancestors, *object_id, type_vars)
} else { } else {
@ -665,8 +668,9 @@ impl TopLevelComposer {
ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. } ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }
) { ) {
return Err(HashSet::from([ return Err(HashSet::from([
"Classes inherited from exception should have no custom fields/methods".into() "Classes inherited from exception should have no custom fields/methods"
])) .into(),
]));
} }
} }
} }
@ -674,7 +678,8 @@ impl TopLevelComposer {
// deal with ancestor of Exception object // deal with ancestor of Exception object
let TopLevelDef::Class { name, ancestors, object_id, .. } = let TopLevelDef::Class { name, ancestors, object_id, .. } =
&mut *self.definition_ast_list[7].0.write() else { &mut *self.definition_ast_list[7].0.write()
else {
unreachable!() unreachable!()
}; };
@ -713,7 +718,7 @@ impl TopLevelComposer {
} }
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors) return Err(errors);
} }
// handle the inherited methods and fields // handle the inherited methods and fields
@ -758,9 +763,14 @@ impl TopLevelComposer {
let mut subst_list = Some(Vec::new()); let mut subst_list = Some(Vec::new());
// unification of previously assigned typevar // unification of previously assigned typevar
let mut unification_helper = |ty, def| -> Result<(), HashSet<String>> { let mut unification_helper = |ty, def| -> Result<(), HashSet<String>> {
let target_ty = let target_ty = get_type_from_type_annotation_kinds(
get_type_from_type_annotation_kinds(&temp_def_list, unifier, &def, &mut subst_list)?; &temp_def_list,
unifier.unify(ty, target_ty) unifier,
&def,
&mut subst_list,
)?;
unifier
.unify(ty, target_ty)
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?; .map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
Ok(()) Ok(())
}; };
@ -793,14 +803,16 @@ impl TopLevelComposer {
} }
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors) return Err(errors);
} }
for (def, _) in def_ast_list.iter().skip(self.builtin_num) { for (def, _) in def_ast_list.iter().skip(self.builtin_num) {
match &*def.read() { match &*def.read() {
TopLevelDef::Class { resolver: Some(resolver), .. } TopLevelDef::Class { resolver: Some(resolver), .. }
| TopLevelDef::Function { resolver: Some(resolver), .. } => { | TopLevelDef::Function { resolver: Some(resolver), .. } => {
if let Err(e) = resolver.handle_deferred_eval(unifier, &temp_def_list, primitives) { if let Err(e) =
resolver.handle_deferred_eval(unifier, &temp_def_list, primitives)
{
errors.insert(e); errors.insert(e);
} }
} }
@ -828,7 +840,8 @@ impl TopLevelComposer {
return Ok(()); return Ok(());
}; };
let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = function_def else { let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = function_def
else {
// not top level function def, skip // not top level function def, skip
return Ok(()); return Ok(());
}; };
@ -857,25 +870,22 @@ impl TopLevelComposer {
"top level function must have unique parameter names \ "top level function must have unique parameter names \
and names should not be the same as the keywords (at {})", and names should not be the same as the keywords (at {})",
x.location x.location
), )]));
])) }
}} }
let arg_with_default: Vec<( let arg_with_default: Vec<(&ast::Located<ast::ArgData<()>>, Option<&ast::Expr>)> =
&ast::Located<ast::ArgData<()>>, args.args
Option<&ast::Expr>, .iter()
)> = args .rev()
.args .zip(
.iter() args.defaults
.rev() .iter()
.zip( .rev()
args.defaults .map(|x| -> Option<&ast::Expr> { Some(x) })
.iter() .chain(std::iter::repeat(None)),
.rev() )
.map(|x| -> Option<&ast::Expr> { Some(x) }) .collect_vec();
.chain(std::iter::repeat(None)),
)
.collect_vec();
arg_with_default arg_with_default
.iter() .iter()
@ -885,12 +895,12 @@ impl TopLevelComposer {
.node .node
.annotation .annotation
.as_ref() .as_ref()
.ok_or_else(|| HashSet::from([ .ok_or_else(|| {
format!( HashSet::from([format!(
"function parameter `{}` needs type annotation at {}", "function parameter `{}` needs type annotation at {}",
x.node.arg, x.location x.node.arg, x.location
), )])
]))? })?
.as_ref(); .as_ref();
let type_annotation = parse_ast_to_type_annotation_kinds( let type_annotation = parse_ast_to_type_annotation_kinds(
@ -926,7 +936,7 @@ impl TopLevelComposer {
temp_def_list.as_ref(), temp_def_list.as_ref(),
unifier, unifier,
&type_annotation, &type_annotation,
&mut None &mut None,
)?; )?;
Ok(FuncArg { Ok(FuncArg {
@ -935,18 +945,16 @@ impl TopLevelComposer {
default_value: match default { default_value: match default {
None => None, None => None,
Some(default) => Some({ Some(default) => Some({
let v = Self::parse_parameter_default_value( let v = Self::parse_parameter_default_value(default, resolver)?;
default, resolver,
)?;
Self::check_default_param_type( Self::check_default_param_type(
&v, &v,
&type_annotation, &type_annotation,
primitives_store, primitives_store,
unifier, unifier,
) )
.map_err( .map_err(|err| {
|err| HashSet::from([format!("{} (at {})", err, x.location), HashSet::from([format!("{} (at {})", err, x.location)])
]))?; })?;
v v
}), }),
}, },
@ -993,7 +1001,7 @@ impl TopLevelComposer {
&temp_def_list, &temp_def_list,
unifier, unifier,
&return_ty_annotation, &return_ty_annotation,
&mut None &mut None,
)? )?
} else { } else {
primitives_store.none primitives_store.none
@ -1016,9 +1024,9 @@ impl TopLevelComposer {
ret: return_ty, ret: return_ty,
vars: function_var_map, vars: function_var_map,
})); }));
unifier.unify(*dummy_ty, function_ty).map_err(|e| HashSet::from([ unifier.unify(*dummy_ty, function_ty).map_err(|e| {
e.at(Some(function_ast.location)).to_display(unifier).to_string(), HashSet::from([e.at(Some(function_ast.location)).to_display(unifier).to_string()])
]))?; })?;
Ok(()) Ok(())
}; };
for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) { for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) {
@ -1030,7 +1038,7 @@ impl TopLevelComposer {
} }
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors) return Err(errors);
} }
Ok(()) Ok(())
} }
@ -1047,14 +1055,9 @@ impl TopLevelComposer {
let (keyword_list, core_config) = core_info; let (keyword_list, core_config) = core_info;
let mut class_def = class_def.write(); let mut class_def = class_def.write();
let TopLevelDef::Class { let TopLevelDef::Class {
object_id, object_id, ancestors, fields, methods, resolver, type_vars, ..
ancestors, } = &mut *class_def
fields, else {
methods,
resolver,
type_vars,
..
} = &mut *class_def else {
unreachable!("here must be toplevel class def"); unreachable!("here must be toplevel class def");
}; };
let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast else { let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast else {
@ -1153,7 +1156,7 @@ impl TopLevelComposer {
annotation_expr, annotation_expr,
vec![(class_id, class_type_vars_def.clone())] vec![(class_id, class_type_vars_def.clone())]
.into_iter() .into_iter()
.collect(), .collect::<HashMap<_, _>>(),
)? )?
}; };
// find type vars within this method parameter type annotation // find type vars within this method parameter type annotation
@ -1218,7 +1221,9 @@ impl TopLevelComposer {
unifier, unifier,
primitives, primitives,
result, result,
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), vec![(class_id, class_type_vars_def.clone())]
.into_iter()
.collect::<HashMap<_, _>>(),
)?; )?;
// find type vars within this return type annotation // find type vars within this return type annotation
let type_vars_within = let type_vars_within =
@ -1312,7 +1317,9 @@ impl TopLevelComposer {
unifier, unifier,
primitives, primitives,
annotation.as_ref(), annotation.as_ref(),
vec![(class_id, class_type_vars_def.clone())].into_iter().collect(), vec![(class_id, class_type_vars_def.clone())]
.into_iter()
.collect::<HashMap<_, _>>(),
)?; )?;
// find type vars within this return type annotation // find type vars within this return type annotation
let type_vars_within = let type_vars_within =
@ -1375,14 +1382,9 @@ impl TopLevelComposer {
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>, type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
) -> Result<(), HashSet<String>> { ) -> Result<(), HashSet<String>> {
let TopLevelDef::Class { let TopLevelDef::Class {
object_id, object_id, ancestors, fields, methods, resolver, type_vars, ..
ancestors, } = class_def
fields, else {
methods,
resolver,
type_vars,
..
} = class_def else {
unreachable!("here must be class def ast") unreachable!("here must be class def ast")
}; };
let ( let (
@ -1414,9 +1416,7 @@ impl TopLevelComposer {
for (anc_method_name, anc_method_ty, anc_method_def_id) in methods { for (anc_method_name, anc_method_ty, anc_method_def_id) in methods {
// find if there is a method with same name in the child class // find if there is a method with same name in the child class
let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id); let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id);
for (class_method_name, class_method_ty, class_method_defid) in for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
&*class_methods_def
{
if class_method_name == anc_method_name { if class_method_name == anc_method_name {
// ignore and handle self // ignore and handle self
// if is __init__ method, no need to check return type // if is __init__ method, no need to check return type
@ -1430,27 +1430,20 @@ impl TopLevelComposer {
if !ok { if !ok {
return Err(HashSet::from([format!( return Err(HashSet::from([format!(
"method {class_method_name} has same name as ancestors' method, but incompatible type"), "method {class_method_name} has same name as ancestors' method, but incompatible type"),
])) ]));
} }
// mark it as added // mark it as added
is_override.insert(*class_method_name); is_override.insert(*class_method_name);
to_be_added = to_be_added = (*class_method_name, *class_method_ty, *class_method_defid);
(*class_method_name, *class_method_ty, *class_method_defid);
break; break;
} }
} }
new_child_methods.push(to_be_added); new_child_methods.push(to_be_added);
} }
// add those that are not overriding method to the new_child_methods // add those that are not overriding method to the new_child_methods
for (class_method_name, class_method_ty, class_method_defid) in for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
&*class_methods_def
{
if !is_override.contains(class_method_name) { if !is_override.contains(class_method_name) {
new_child_methods.push(( new_child_methods.push((*class_method_name, *class_method_ty, *class_method_defid));
*class_method_name,
*class_method_ty,
*class_method_defid,
));
} }
} }
// use the new_child_methods to replace all the elements in `class_methods_def` // use the new_child_methods to replace all the elements in `class_methods_def`
@ -1466,8 +1459,8 @@ impl TopLevelComposer {
for (class_field_name, ..) in &*class_fields_def { for (class_field_name, ..) in &*class_fields_def {
if class_field_name == anc_field_name { if class_field_name == anc_field_name {
return Err(HashSet::from([format!( return Err(HashSet::from([format!(
"field `{class_field_name}` has already declared in the ancestor classes"), "field `{class_field_name}` has already declared in the ancestor classes"
])) )]));
} }
} }
new_child_fields.push(to_be_added); new_child_fields.push(to_be_added);
@ -1499,24 +1492,30 @@ impl TopLevelComposer {
// first, fix function typevar ids // first, fix function typevar ids
// they may be changed with our use of placeholders // they may be changed with our use of placeholders
for (def, _) in definition_ast_list.iter().skip(self.builtin_num) { for (def, _) in definition_ast_list.iter().skip(self.builtin_num) {
if let TopLevelDef::Function { if let TopLevelDef::Function { signature, var_id, .. } = &mut *def.write() {
signature,
var_id,
..
} = &mut *def.write() {
if let TypeEnum::TFunc(FunSignature { args, ret, vars }) = if let TypeEnum::TFunc(FunSignature { args, ret, vars }) =
unifier.get_ty(*signature).as_ref() { unifier.get_ty(*signature).as_ref()
let new_var_ids = vars.values().map(|v| match &*unifier.get_ty(*v) { {
TypeEnum::TVar{id, ..} => *id, let new_var_ids = vars
_ => unreachable!(), .values()
}).collect_vec(); .map(|v| match &*unifier.get_ty(*v) {
TypeEnum::TVar { id, .. } => *id,
_ => unreachable!(),
})
.collect_vec();
if new_var_ids != *var_id { if new_var_ids != *var_id {
let new_signature = FunSignature { let new_signature = FunSignature {
args: args.clone(), args: args.clone(),
ret: *ret, ret: *ret,
vars: new_var_ids.iter().zip(vars.values()).map(|(id, v)| (*id, *v)).collect(), vars: new_var_ids
.iter()
.zip(vars.values())
.map(|(id, v)| (*id, *v))
.collect(),
}; };
unifier.unification_table.set_value(*signature, Rc::new(TypeEnum::TFunc(new_signature))); unifier
.unification_table
.set_value(*signature, Rc::new(TypeEnum::TFunc(new_signature)));
*var_id = new_var_ids; *var_id = new_var_ids;
} }
} }
@ -1542,7 +1541,7 @@ impl TopLevelComposer {
&def_list, &def_list,
unifier, unifier,
&make_self_type_annotation(type_vars, *object_id), &make_self_type_annotation(type_vars, *object_id),
&mut None &mut None,
)?; )?;
if ancestors if ancestors
.iter() .iter()
@ -1590,9 +1589,12 @@ impl TopLevelComposer {
}; };
constructors.push((i, signature, definition_extension.len())); constructors.push((i, signature, definition_extension.len()));
definition_extension.push((Arc::new(RwLock::new(cons_fun)), None)); definition_extension.push((Arc::new(RwLock::new(cons_fun)), None));
unifier.unify(constructor.unwrap(), signature).map_err(|e| HashSet::from([ unifier.unify(constructor.unwrap(), signature).map_err(|e| {
e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string() HashSet::from([e
]))?; .at(Some(ast.as_ref().unwrap().location))
.to_display(unifier)
.to_string()])
})?;
return Ok(()); return Ok(());
} }
let mut init_id: Option<DefinitionId> = None; let mut init_id: Option<DefinitionId> = None;
@ -1605,7 +1607,8 @@ impl TopLevelComposer {
init_id = Some(*id); init_id = Some(*id);
let func_ty_enum = unifier.get_ty(*func_sig); let func_ty_enum = unifier.get_ty(*func_sig);
let TypeEnum::TFunc(FunSignature { args, vars, .. }) = let TypeEnum::TFunc(FunSignature { args, vars, .. }) =
func_ty_enum.as_ref() else { func_ty_enum.as_ref()
else {
unreachable!("must be typeenum::tfunc") unreachable!("must be typeenum::tfunc")
}; };
@ -1620,9 +1623,12 @@ impl TopLevelComposer {
ret: self_type, ret: self_type,
vars: contor_type_vars, vars: contor_type_vars,
})); }));
unifier.unify(constructor.unwrap(), contor_type).map_err(|e| HashSet::from([ unifier.unify(constructor.unwrap(), contor_type).map_err(|e| {
e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string() HashSet::from([e
]))?; .at(Some(ast.as_ref().unwrap().location))
.to_display(unifier)
.to_string()])
})?;
// class field instantiation check // class field instantiation check
if let (Some(init_id), false) = (init_id, fields.is_empty()) { if let (Some(init_id), false) = (init_id, fields.is_empty()) {
@ -1641,7 +1647,7 @@ impl TopLevelComposer {
class_name, class_name,
body[0].location, body[0].location,
), ),
])) ]));
} }
} }
} }
@ -1658,11 +1664,12 @@ impl TopLevelComposer {
} }
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors) return Err(errors);
} }
for (i, signature, id) in constructors { for (i, signature, id) in constructors {
let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() else { let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write()
else {
unreachable!() unreachable!()
}; };
@ -1697,8 +1704,8 @@ impl TopLevelComposer {
} = &mut *function_def } = &mut *function_def
{ {
let signature_ty_enum = unifier.get_ty(*signature); let signature_ty_enum = unifier.get_ty(*signature);
let TypeEnum::TFunc(FunSignature { args, ret, vars }) = let TypeEnum::TFunc(FunSignature { args, ret, vars }) = signature_ty_enum.as_ref()
signature_ty_enum.as_ref() else { else {
unreachable!("must be typeenum::tfunc") unreachable!("must be typeenum::tfunc")
}; };
@ -1714,10 +1721,7 @@ impl TopLevelComposer {
let ty_ann = make_self_type_annotation(type_vars, *class_id); let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds( let self_ty = get_type_from_type_annotation_kinds(
&def_list, &def_list, unifier, &ty_ann, &mut None,
unifier,
&ty_ann,
&mut None
)?; )?;
vars.extend(type_vars.iter().map(|ty| { vars.extend(type_vars.iter().map(|ty| {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else { let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {
@ -1739,7 +1743,9 @@ impl TopLevelComposer {
.values() .values()
.map(|ty| { .map(|ty| {
unifier.get_instantiations(*ty).unwrap_or_else(|| { unifier.get_instantiations(*ty).unwrap_or_else(|| {
let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty) else { let TypeEnum::TVar { name, loc, is_const_generic: false, .. } =
&*unifier.get_ty(*ty)
else {
unreachable!() unreachable!()
}; };
@ -1779,8 +1785,7 @@ impl TopLevelComposer {
let class_ty_var_ids = type_vars let class_ty_var_ids = type_vars
.iter() .iter()
.map(|x| { .map(|x| {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
{
*id *id
} else { } else {
unreachable!("must be type var here"); unreachable!("must be type var here");
@ -1839,7 +1844,8 @@ impl TopLevelComposer {
}; };
let ast::StmtKind::FunctionDef { body, decorator_list, .. } = let ast::StmtKind::FunctionDef { body, decorator_list, .. } =
ast.clone().unwrap().node else { ast.clone().unwrap().node
else {
unreachable!("must be function def ast") unreachable!("must be function def ast")
}; };
if !decorator_list.is_empty() if !decorator_list.is_empty()
@ -1857,13 +1863,12 @@ impl TopLevelComposer {
continue; continue;
} }
let fun_body = body let fun_body = body
.into_iter() .into_iter()
.map(|b| inferencer.fold_stmt(b)) .map(|b| inferencer.fold_stmt(b))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let returned = let returned = inferencer.check_block(fun_body.as_slice(), &mut identifiers)?;
inferencer.check_block(fun_body.as_slice(), &mut identifiers)?;
{ {
// check virtuals // check virtuals
let defs = ctx.definitions.read(); let defs = ctx.definitions.read();
@ -1873,9 +1878,9 @@ impl TopLevelComposer {
if let TypeEnum::TObj { obj_id, .. } = &*base { if let TypeEnum::TObj { obj_id, .. } = &*base {
*obj_id *obj_id
} else { } else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!("Base type should be a class (at {loc})"), "Base type should be a class (at {loc})"
])) )]));
} }
}; };
let subtype_id = { let subtype_id = {
@ -1887,7 +1892,7 @@ impl TopLevelComposer {
let subtype_repr = inferencer.unifier.stringify(*subtype); let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(HashSet::from([format!( return Err(HashSet::from([format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"), "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"),
])) ]));
} }
}; };
let subtype_entry = defs[subtype_id.0].read(); let subtype_entry = defs[subtype_id.0].read();
@ -1902,7 +1907,7 @@ impl TopLevelComposer {
let subtype_repr = inferencer.unifier.stringify(*subtype); let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(HashSet::from([format!( return Err(HashSet::from([format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"), "Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"),
])) ]));
} }
} }
} }
@ -1912,7 +1917,9 @@ impl TopLevelComposer {
inst_ret, inst_ret,
&mut |id| { &mut |id| {
let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read() let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read()
else { unreachable!("must be class id here") }; else {
unreachable!("must be class id here")
};
name.to_string() name.to_string()
}, },
@ -1924,11 +1931,16 @@ impl TopLevelComposer {
ret_str, ret_str,
name, name,
ast.as_ref().unwrap().location ast.as_ref().unwrap().location
),])) )]));
} }
instance_to_stmt.insert( instance_to_stmt.insert(
get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())), get_subst_key(
unifier,
self_type,
&subst,
Some(&vars.keys().copied().collect()),
),
FunInstance { FunInstance {
body: Arc::new(fun_body), body: Arc::new(fun_body),
unifier_id: 0, unifier_id: 0,
@ -1950,7 +1962,7 @@ impl TopLevelComposer {
} }
} }
if !errors.is_empty() { if !errors.is_empty() {
return Err(errors) return Err(errors);
} }
Ok(()) Ok(())
} }

View File

@ -4,75 +4,270 @@ use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys; use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{Mapping, VarMap}; use crate::typecheck::typedef::{Mapping, VarMap};
use nac3parser::ast::{Constant, Location}; use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
use super::*; use super::*;
/// Structure storing [`DefinitionId`] for primitive types. /// All primitive types and functions in nac3core.
#[derive(Clone, Copy)] #[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
pub struct PrimitiveDefinitionIds { pub enum PrimDef {
pub int32: DefinitionId, Int32,
pub int64: DefinitionId, Int64,
pub uint32: DefinitionId, Float,
pub uint64: DefinitionId, Bool,
pub float: DefinitionId, None,
pub bool: DefinitionId, Range,
pub none: DefinitionId, Str,
pub range: DefinitionId, Exception,
pub str: DefinitionId, UInt32,
pub exception: DefinitionId, UInt64,
pub option: DefinitionId, Option,
pub ndarray: DefinitionId, OptionIsSome,
OptionIsNone,
OptionUnwrap,
NDArray,
NDArrayCopy,
NDArrayFill,
FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunNpNDArray,
FunNpEmpty,
FunNpZeros,
FunNpOnes,
FunNpFull,
FunNpArray,
FunNpEye,
FunNpIdentity,
FunRound,
FunRound64,
FunNpRound,
FunRange,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunNpFloor,
FunCeil,
FunCeil64,
FunNpCeil,
FunLen,
FunMin,
FunNpMin,
FunNpMinimum,
FunMax,
FunNpMax,
FunNpMaximum,
FunAbs,
FunNpIsNan,
FunNpIsInf,
FunNpSin,
FunNpCos,
FunNpExp,
FunNpExp2,
FunNpLog,
FunNpLog10,
FunNpLog2,
FunNpFabs,
FunNpSqrt,
FunNpRint,
FunNpTan,
FunNpArcsin,
FunNpArccos,
FunNpArctan,
FunNpSinh,
FunNpCosh,
FunNpTanh,
FunNpArcsinh,
FunNpArccosh,
FunNpArctanh,
FunNpExpm1,
FunNpCbrt,
FunSpSpecErf,
FunSpSpecErfc,
FunSpSpecGamma,
FunSpSpecGammaln,
FunSpSpecJ0,
FunSpSpecJ1,
FunNpArctan2,
FunNpCopysign,
FunNpFmax,
FunNpFmin,
FunNpLdExp,
FunNpHypot,
FunNpNextAfter,
FunSome,
} }
impl PrimitiveDefinitionIds { /// Associated details of a [`PrimDef`]
/// Returns all [`DefinitionId`] of primitives as a [`Vec`]. pub enum PrimDefDetails {
PrimFunction { name: &'static str, simple_name: &'static str },
PrimClass { name: &'static str },
}
impl PrimDef {
/// Get the assigned [`DefinitionId`] of this [`PrimDef`].
/// ///
/// There are no guarantees on ordering of the IDs. /// The assigned definition ID is defined by the position this [`PrimDef`] enum unit variant is defined at,
/// with the first `PrimDef`'s definition id being `0`.
#[must_use] #[must_use]
fn as_vec(&self) -> Vec<DefinitionId> { pub fn id(&self) -> DefinitionId {
vec![ DefinitionId(*self as usize)
self.int32,
self.int64,
self.uint32,
self.uint64,
self.float,
self.bool,
self.none,
self.range,
self.str,
self.exception,
self.option,
self.ndarray,
]
} }
/// Returns an iterator over all [`DefinitionId`]s of this instance in indeterminate order. /// Check if a definition ID is that of a [`PrimDef`].
pub fn iter(&self) -> impl Iterator<Item=DefinitionId> { #[must_use]
self.as_vec().into_iter() pub fn contains_id(id: DefinitionId) -> bool {
Self::iter().any(|prim| prim.id() == id)
} }
/// Returns the primitive with the largest [`DefinitionId`]. /// Get the definition "simple name" of this [`PrimDef`].
///
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::simple_name`].
///
/// If the [`PrimDef`] is a class, this returns [`None`].
#[must_use] #[must_use]
pub fn max_id(&self) -> DefinitionId { pub fn simple_name(&self) -> &'static str {
self.iter().max().unwrap() match self.details() {
PrimDefDetails::PrimFunction { simple_name, .. } => simple_name,
PrimDefDetails::PrimClass { .. } => {
panic!("PrimDef {self:?} has no simple_name as it is not a function.")
}
}
}
/// Get the definition "name" of this [`PrimDef`].
///
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::name`].
///
/// If the [`PrimDef`] is a class, this corresponds to [`TopLevelDef::Class::name`].
#[must_use]
pub fn name(&self) -> &'static str {
match self.details() {
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name,
}
}
/// Get the associated details of this [`PrimDef`]
#[must_use]
pub fn details(self) -> PrimDefDetails {
fn class(name: &'static str) -> PrimDefDetails {
PrimDefDetails::PrimClass { name }
}
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
PrimDefDetails::PrimFunction { simple_name: simple_name.unwrap_or(name), name }
}
match self {
PrimDef::Int32 => class("int32"),
PrimDef::Int64 => class("int64"),
PrimDef::Float => class("float"),
PrimDef::Bool => class("bool"),
PrimDef::None => class("none"),
PrimDef::Range => class("range"),
PrimDef::Str => class("str"),
PrimDef::Exception => class("Exception"),
PrimDef::UInt32 => class("uint32"),
PrimDef::UInt64 => class("uint64"),
PrimDef::Option => class("Option"),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::NDArray => class("ndarray"),
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunNpNDArray => fun("np_ndarray", None),
PrimDef::FunNpEmpty => fun("np_empty", None),
PrimDef::FunNpZeros => fun("np_zeros", None),
PrimDef::FunNpOnes => fun("np_ones", None),
PrimDef::FunNpFull => fun("np_full", None),
PrimDef::FunNpArray => fun("np_array", None),
PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRange => fun("range", None),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunNpFloor => fun("np_floor", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunNpCeil => fun("np_ceil", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunNpMin => fun("np_min", None),
PrimDef::FunNpMinimum => fun("np_minimum", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunNpMax => fun("np_max", None),
PrimDef::FunNpMaximum => fun("np_maximum", None),
PrimDef::FunAbs => fun("abs", None),
PrimDef::FunNpIsNan => fun("np_isnan", None),
PrimDef::FunNpIsInf => fun("np_isinf", None),
PrimDef::FunNpSin => fun("np_sin", None),
PrimDef::FunNpCos => fun("np_cos", None),
PrimDef::FunNpExp => fun("np_exp", None),
PrimDef::FunNpExp2 => fun("np_exp2", None),
PrimDef::FunNpLog => fun("np_log", None),
PrimDef::FunNpLog10 => fun("np_log10", None),
PrimDef::FunNpLog2 => fun("np_log2", None),
PrimDef::FunNpFabs => fun("np_fabs", None),
PrimDef::FunNpSqrt => fun("np_sqrt", None),
PrimDef::FunNpRint => fun("np_rint", None),
PrimDef::FunNpTan => fun("np_tan", None),
PrimDef::FunNpArcsin => fun("np_arcsin", None),
PrimDef::FunNpArccos => fun("np_arccos", None),
PrimDef::FunNpArctan => fun("np_arctan", None),
PrimDef::FunNpSinh => fun("np_sinh", None),
PrimDef::FunNpCosh => fun("np_cosh", None),
PrimDef::FunNpTanh => fun("np_tanh", None),
PrimDef::FunNpArcsinh => fun("np_arcsinh", None),
PrimDef::FunNpArccosh => fun("np_arccosh", None),
PrimDef::FunNpArctanh => fun("np_arctanh", None),
PrimDef::FunNpExpm1 => fun("np_expm1", None),
PrimDef::FunNpCbrt => fun("np_cbrt", None),
PrimDef::FunSpSpecErf => fun("sp_spec_erf", None),
PrimDef::FunSpSpecErfc => fun("sp_spec_erfc", None),
PrimDef::FunSpSpecGamma => fun("sp_spec_gamma", None),
PrimDef::FunSpSpecGammaln => fun("sp_spec_gammaln", None),
PrimDef::FunSpSpecJ0 => fun("sp_spec_j0", None),
PrimDef::FunSpSpecJ1 => fun("sp_spec_j1", None),
PrimDef::FunNpArctan2 => fun("np_arctan2", None),
PrimDef::FunNpCopysign => fun("np_copysign", None),
PrimDef::FunNpFmax => fun("np_fmax", None),
PrimDef::FunNpFmin => fun("np_fmin", None),
PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
PrimDef::FunSome => fun("Some", None),
}
} }
} }
/// The [definition IDs][DefinitionId] for primitive types. /// Asserts that a [`PrimDef`] is in an allowlist.
pub const PRIMITIVE_DEF_IDS: PrimitiveDefinitionIds = PrimitiveDefinitionIds { ///
int32: DefinitionId(0), /// Like `debug_assert!`, this statements of this function are only
int64: DefinitionId(1), /// enabled if `cfg!(debug_assertions)` is true.
uint32: DefinitionId(8), pub fn debug_assert_prim_is_allowed(prim: PrimDef, allowlist: &[PrimDef]) {
uint64: DefinitionId(9), if cfg!(debug_assertions) {
float: DefinitionId(2), let allowed = allowlist.iter().any(|p| *p == prim);
bool: DefinitionId(3), assert!(
none: DefinitionId(4), allowed,
range: DefinitionId(5), "Disallowed primitive definition. Got {prim:?}, but expects it to be in {allowlist:?}"
str: DefinitionId(6), );
exception: DefinitionId(7), }
option: DefinitionId(10), }
ndarray: DefinitionId(14),
};
impl TopLevelDef { impl TopLevelDef {
pub fn to_string(&self, unifier: &mut Unifier) -> String { pub fn to_string(&self, unifier: &mut Unifier) -> String {
@ -116,42 +311,42 @@ impl TopLevelComposer {
pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) { pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int32, obj_id: PrimDef::Int32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int64, obj_id: PrimDef::Int64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.float, obj_id: PrimDef::Float.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.bool, obj_id: PrimDef::Bool.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.none, obj_id: PrimDef::None.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.range, obj_id: PrimDef::Range.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.str, obj_id: PrimDef::Str.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.exception, obj_id: PrimDef::Exception.id(),
fields: vec![ fields: vec![
("__name__".into(), (int32, true)), ("__name__".into(), (int32, true)),
("__file__".into(), (str, true)), ("__file__".into(), (str, true)),
@ -168,12 +363,12 @@ impl TopLevelComposer {
params: VarMap::new(), params: VarMap::new(),
}); });
let uint32 = unifier.add_ty(TypeEnum::TObj { let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint32, obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let uint64 = unifier.add_ty(TypeEnum::TObj { let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint64, obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
@ -190,7 +385,7 @@ impl TopLevelComposer {
vars: VarMap::from([(option_type_var.1, option_type_var.0)]), vars: VarMap::from([(option_type_var.1, option_type_var.0)]),
})); }));
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.option, obj_id: PrimDef::Option.id(),
fields: vec![ fields: vec![
("is_some".into(), (is_some_type_fun_ty, true)), ("is_some".into(), (is_some_type_fun_ty, true)),
("is_none".into(), (is_some_type_fun_ty, true)), ("is_none".into(), (is_some_type_fun_ty, true)),
@ -208,7 +403,8 @@ impl TopLevelComposer {
}; };
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None); let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![], args: vec![],
@ -219,13 +415,11 @@ impl TopLevelComposer {
]), ]),
})); }));
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![ args: vec![FuncArg {
FuncArg { name: "value".into(),
name: "value".into(), ty: ndarray_dtype_tvar.0,
ty: ndarray_dtype_tvar.0, default_value: None,
default_value: None, }],
},
],
ret: none, ret: none,
vars: VarMap::from([ vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
@ -233,7 +427,7 @@ impl TopLevelComposer {
]), ]),
})); }));
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray, obj_id: PrimDef::NDArray.id(),
fields: Mapping::from([ fields: Mapping::from([
("copy".into(), (ndarray_copy_fun_ty, true)), ("copy".into(), (ndarray_copy_fun_ty, true)),
("fill".into(), (ndarray_fill_fun_ty, true)), ("fill".into(), (ndarray_fill_fun_ty, true)),
@ -393,9 +587,7 @@ impl TopLevelComposer {
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() { if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
Ok(*id) Ok(*id)
} else { } else {
Err(HashSet::from([ Err(HashSet::from(["not type var".to_string()]))
"not type var".to_string(),
]))
} }
} }
@ -412,25 +604,27 @@ impl TopLevelComposer {
let ( let (
TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }), TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }),
TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }), TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }),
) = (this, other) else { ) = (this, other)
else {
unreachable!("this function must be called with function type") unreachable!("this function must be called with function type")
}; };
// check args // check args
let args_ok = this_args let args_ok =
.iter() this_args
.map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap())) .iter()
.zip(other_args.iter().map(|FuncArg { name, ty, .. }| { .map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap()))
(name, type_var_to_concrete_def.get(ty).unwrap()) .zip(other_args.iter().map(|FuncArg { name, ty, .. }| {
})) (name, type_var_to_concrete_def.get(ty).unwrap())
.all(|(this, other)| { }))
if this.0 == &"self".into() && this.0 == other.0 { .all(|(this, other)| {
true if this.0 == &"self".into() && this.0 == other.0 {
} else { true
this.0 == other.0 } else {
&& check_overload_type_annotation_compatible(this.1, other.1, unifier) this.0 == other.0
} && check_overload_type_annotation_compatible(this.1, other.1, unifier)
}); }
});
// check rets // check rets
let ret_ok = check_overload_type_annotation_compatible( let ret_ok = check_overload_type_annotation_compatible(
@ -473,12 +667,10 @@ impl TopLevelComposer {
} }
} => } =>
{ {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "redundant type annotation for class fields at {}",
"redundant type annotation for class fields at {}", s.location
s.location )]))
),
]))
} }
ast::StmtKind::Assign { targets, .. } => { ast::StmtKind::Assign { targets, .. } => {
for t in targets { for t in targets {
@ -602,112 +794,109 @@ pub fn parse_parameter_default_value(
Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( Constant::Tuple(tuple) => Ok(SymbolValue::Tuple(
tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?, tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?,
)), )),
Constant::None => Err(HashSet::from([ Constant::None => Err(HashSet::from([format!(
format!( "`None` is not supported, use `none` for option type instead ({loc})"
"`None` is not supported, use `none` for option type instead ({loc})" )])),
),
])),
_ => unimplemented!("this constant is not supported at {}", loc), _ => unimplemented!("this constant is not supported at {}", loc),
} }
} }
match &default.node { match &default.node {
ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location), ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location),
ast::ExprKind::Call { func, args, .. } if args.len() == 1 => { ast::ExprKind::Call { func, args, .. } if args.len() == 1 => match &func.node {
match &func.node { ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node {
ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node { ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => { let v: Result<i64, _> = (*v).try_into();
let v: Result<i64, _> = (*v).try_into(); match v {
match v { Ok(v) => Ok(SymbolValue::I64(v)),
Ok(v) => Ok(SymbolValue::I64(v)), _ => Err(HashSet::from([format!(
_ => Err(HashSet::from([ "default param value out of range at {}",
format!("default param value out of range at {}", default.location) default.location
])), )])),
}
} }
_ => Err(HashSet::from([
format!("only allow constant integer here at {}", default.location),
]))
} }
ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node { _ => Err(HashSet::from([format!(
ast::ExprKind::Constant { value: Constant::Int(v), .. } => { "only allow constant integer here at {}",
let v: Result<u32, _> = (*v).try_into(); default.location
match v { )])),
Ok(v) => Ok(SymbolValue::U32(v)), },
_ => Err(HashSet::from([ ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node {
format!("default param value out of range at {}", default.location), ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
])), let v: Result<u32, _> = (*v).try_into();
} match v {
Ok(v) => Ok(SymbolValue::U32(v)),
_ => Err(HashSet::from([format!(
"default param value out of range at {}",
default.location
)])),
} }
_ => Err(HashSet::from([
format!("only allow constant integer here at {}", default.location),
]))
} }
ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node { _ => Err(HashSet::from([format!(
ast::ExprKind::Constant { value: Constant::Int(v), .. } => { "only allow constant integer here at {}",
let v: Result<u64, _> = (*v).try_into(); default.location
match v { )])),
Ok(v) => Ok(SymbolValue::U64(v)), },
_ => Err(HashSet::from([ ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node {
format!("default param value out of range at {}", default.location), ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
])), let v: Result<u64, _> = (*v).try_into();
} match v {
Ok(v) => Ok(SymbolValue::U64(v)),
_ => Err(HashSet::from([format!(
"default param value out of range at {}",
default.location
)])),
} }
_ => Err(HashSet::from([
format!("only allow constant integer here at {}", default.location),
]))
} }
ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok( _ => Err(HashSet::from([format!(
SymbolValue::OptionSome( "only allow constant integer here at {}",
Box::new(parse_parameter_default_value(&args[0], resolver)?) default.location
) )])),
), },
_ => Err(HashSet::from([ ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(SymbolValue::OptionSome(
format!("unsupported default parameter at {}", default.location), Box::new(parse_parameter_default_value(&args[0], resolver)?),
])), )),
} _ => Err(HashSet::from([format!(
} "unsupported default parameter at {}",
ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts default.location
.iter() )])),
.map(|x| parse_parameter_default_value(x, resolver)) },
.collect::<Result<Vec<_>, _>>()? ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(
elts.iter()
.map(|x| parse_parameter_default_value(x, resolver))
.collect::<Result<Vec<_>, _>>()?,
)), )),
ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone), ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone),
ast::ExprKind::Name { id, .. } => { ast::ExprKind::Name { id, .. } => {
resolver.get_default_param_value(default).ok_or_else( resolver.get_default_param_value(default).ok_or_else(|| {
|| HashSet::from([ HashSet::from([format!(
format!( "`{}` cannot be used as a default parameter at {} \
"`{}` cannot be used as a default parameter at {} \
(not primitive type, option or tuple / not defined?)", (not primitive type, option or tuple / not defined?)",
id, id, default.location
default.location )])
), })
])
)
} }
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!( "unsupported default parameter (not primitive type, option or tuple) at {}",
"unsupported default parameter (not primitive type, option or tuple) at {}", default.location
default.location )])),
),
]))
} }
} }
/// Obtains the element type of an array-like type. /// Obtains the element type of an array-like type.
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type { pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
match &*unifier.get_ty(ty) { match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(unifier, ty).0, unpack_ndarray_var_tys(unifier, ty).0
}
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty), TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
_ => ty _ => ty,
} }
} }
/// Obtains the number of dimensions of an array-like type. /// Obtains the number of dimensions of an array-like type.
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 { pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
match &*unifier.get_ty(ty) { match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1; let ndims = unpack_ndarray_var_tys(unifier, ty).1;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else { let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims)) panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
@ -721,6 +910,6 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
} }
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1, TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1,
_ => 0 _ => 0,
} }
} }

View File

@ -8,7 +8,9 @@ use std::{
use super::codegen::CodeGenContext; use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore; use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap}; use super::typecheck::typedef::{
FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap,
};
use crate::{ use crate::{
codegen::CodeGenerator, codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum}, symbol_resolver::{SymbolResolver, ValueEnum},
@ -32,16 +34,15 @@ use type_annotation::*;
#[cfg(test)] #[cfg(test)]
mod test; mod test;
type GenCallCallback = type GenCallCallback = dyn for<'ctx, 'a> Fn(
dyn for<'ctx, 'a> Fn( &mut CodeGenContext<'ctx, 'a>,
&mut CodeGenContext<'ctx, 'a>, Option<(Type, ValueEnum<'ctx>)>,
Option<(Type, ValueEnum<'ctx>)>, (&FunSignature, DefinitionId),
(&FunSignature, DefinitionId), Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
Vec<(Option<StrRef>, ValueEnum<'ctx>)>, &mut dyn CodeGenerator,
&mut dyn CodeGenerator, ) -> Result<Option<BasicValueEnum<'ctx>>, String>
) -> Result<Option<BasicValueEnum<'ctx>>, String> + Send
+ Send + Sync;
+ Sync;
pub struct GenCall { pub struct GenCall {
fp: Box<GenCallCallback>, fp: Box<GenCallCallback>,
@ -53,7 +54,7 @@ impl GenCall {
GenCall { fp } GenCall { fp }
} }
/// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given /// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
/// `reason`. /// `reason`.
#[must_use] #[must_use]
pub fn create_dummy(reason: String) -> GenCall { pub fn create_dummy(reason: String) -> GenCall {

View File

@ -1,14 +1,14 @@
use itertools::Itertools;
use crate::{ use crate::{
toplevel::helper::PRIMITIVE_DEF_IDS, toplevel::helper::PrimDef,
typecheck::{ typecheck::{
type_inferencer::PrimitiveStore, type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap}, typedef::{Type, TypeEnum, Unifier, VarMap},
}, },
}; };
use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments. /// Creates a `ndarray` [`Type`] with the given type arguments.
/// ///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not /// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized. /// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not /// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
@ -37,15 +37,13 @@ pub fn subst_ndarray_tvars(
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
}; };
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
if dtype.is_none() && ndims.is_none() { if dtype.is_none() && ndims.is_none() {
return ndarray return ndarray;
} }
let tvar_ids = params.iter() let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
.map(|(obj_id, _)| *obj_id)
.collect_vec();
debug_assert_eq!(tvar_ids.len(), 2); debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new(); let mut tvar_subst = VarMap::new();
@ -59,45 +57,29 @@ pub fn subst_ndarray_tvars(
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
} }
fn unpack_ndarray_tvars( fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(u32, Type)> {
unifier: &mut Unifier,
ndarray: Type,
) -> Vec<(u32, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
}; };
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2); debug_assert_eq!(params.len(), 2);
params.iter() params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id) .sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty)) .map(|(var_id, ty)| (*var_id, *ty))
.collect_vec() .collect_vec()
} }
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds /// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` /// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively. /// respectively.
pub fn unpack_ndarray_var_ids( pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (u32, u32) {
unifier: &mut Unifier, unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
ndarray: Type,
) -> (u32, u32) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.0)
.collect_tuple()
.unwrap()
} }
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to /// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. /// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys( pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unifier: &mut Unifier, unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
ndarray: Type,
) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.1)
.collect_tuple()
.unwrap()
} }

View File

@ -65,7 +65,11 @@ impl SymbolResolver for Resolver {
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0.id_to_def.lock().get(&id).cloned() self.0
.id_to_def
.lock()
.get(&id)
.cloned()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()])) .ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
} }

View File

@ -1,7 +1,7 @@
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::typecheck::typedef::VarMap;
use super::*; use super::*;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant; use nac3parser::ast::Constant;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -29,9 +29,7 @@ impl TypeAnnotation {
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty), Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
CustomClass { id, params } => { CustomClass { id, params } => {
let class_name = if let Some(ref top) = unifier.top_level { let class_name = if let Some(ref top) = unifier.top_level {
if let TopLevelDef::Class { name, .. } = if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() {
&*top.definitions.read()[id.0].read()
{
(*name).into() (*name).into()
} else { } else {
unreachable!() unreachable!()
@ -39,24 +37,26 @@ impl TypeAnnotation {
} else { } else {
format!("class_def_{}", id.0) format!("class_def_{}", id.0)
}; };
format!( format!("{}{}", class_name, {
"{}{}", let param_list =
class_name, params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
{ if param_list.is_empty() {
let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", "); String::new()
if param_list.is_empty() { } else {
String::new() format!("[{param_list}]")
} else {
format!("[{param_list}]")
}
} }
) })
}
Literal(values) => {
format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", "))
} }
Literal(values) => format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", ")),
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)), Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)), List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => { Tuple(types) => {
format!("tuple[{}]", types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")) format!(
"tuple[{}]",
types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")
)
} }
} }
} }
@ -68,18 +68,18 @@ impl TypeAnnotation {
/// generic variables associated with the definition. /// generic variables associated with the definition.
/// * `type_var` - The type variable associated with the type argument currently being parsed. Pass /// * `type_var` - The type variable associated with the type argument currently being parsed. Pass
/// [`None`] when this function is invoked externally. /// [`None`] when this function is invoked externally.
pub fn parse_ast_to_type_annotation_kinds<T>( pub fn parse_ast_to_type_annotation_kinds<T, S: std::hash::BuildHasher + Clone>(
resolver: &(dyn SymbolResolver + Send + Sync), resolver: &(dyn SymbolResolver + Send + Sync),
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
expr: &ast::Expr<T>, expr: &ast::Expr<T>,
// the key stores the type_var of this topleveldef::class, we only need this field here // the key stores the type_var of this topleveldef::class, we only need this field here
locked: HashMap<DefinitionId, Vec<Type>>, locked: HashMap<DefinitionId, Vec<Type>, S>,
) -> Result<TypeAnnotation, HashSet<String>> { ) -> Result<TypeAnnotation, HashSet<String>> {
let name_handle = |id: &StrRef, let name_handle = |id: &StrRef,
unifier: &mut Unifier, unifier: &mut Unifier,
locked: HashMap<DefinitionId, Vec<Type>>| { locked: HashMap<DefinitionId, Vec<Type>, S>| {
if id == &"int32".into() { if id == &"int32".into() {
Ok(TypeAnnotation::Primitive(primitives.int32)) Ok(TypeAnnotation::Primitive(primitives.int32))
} else if id == &"int64".into() { } else if id == &"int64".into() {
@ -95,7 +95,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} else if id == &"str".into() { } else if id == &"str".into() {
Ok(TypeAnnotation::Primitive(primitives.str)) Ok(TypeAnnotation::Primitive(primitives.str))
} else if id == &"Exception".into() { } else if id == &"Exception".into() {
Ok(TypeAnnotation::CustomClass { id: PRIMITIVE_DEF_IDS.exception, params: Vec::default() }) Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) { } else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
let type_vars = { let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read(); let def_read = top_level_defs[obj_id.0].try_read();
@ -103,12 +103,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if let TopLevelDef::Class { type_vars, .. } = &*def_read { if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone() type_vars.clone()
} else { } else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "function cannot be used as a type (at {})",
"function cannot be used as a type (at {})", expr.location
expr.location )]));
),
]))
} }
} else { } else {
locked.get(&obj_id).unwrap().clone() locked.get(&obj_id).unwrap().clone()
@ -116,13 +114,11 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
}; };
// check param number here // check param number here
if !type_vars.is_empty() { if !type_vars.is_empty() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "expect {} type variable parameter but got 0 (at {})",
"expect {} type variable parameter but got 0 (at {})", type_vars.len(),
type_vars.len(), expr.location,
expr.location, )]));
),
]))
} }
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) { } else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
@ -131,14 +127,16 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
unifier.unify(var, ty).unwrap(); unifier.unify(var, ty).unwrap();
Ok(TypeAnnotation::TypeVar(ty)) Ok(TypeAnnotation::TypeVar(ty))
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!(
format!("`{}` is not a valid type annotation (at {})", id, expr.location), "`{}` is not a valid type annotation (at {})",
])) id, expr.location
)]))
} }
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!(
format!("`{}` is not a valid type annotation (at {})", id, expr.location), "`{}` is not a valid type annotation (at {})",
])) id, expr.location
)]))
} }
}; };
@ -146,12 +144,14 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
|id: &StrRef, |id: &StrRef,
slice: &ast::Expr<T>, slice: &ast::Expr<T>,
unifier: &mut Unifier, unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>>| { mut locked: HashMap<DefinitionId, Vec<Type>, S>| {
if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()].contains(id) if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()]
.contains(id)
{ {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!("keywords cannot be class name (at {})", expr.location), "keywords cannot be class name (at {})",
])) expr.location
)]));
} }
let obj_id = resolver.get_identifier_def(*id)?; let obj_id = resolver.get_identifier_def(*id)?;
let type_vars = { let type_vars = {
@ -174,14 +174,12 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
vec![slice] vec![slice]
}; };
if type_vars.len() != params_ast.len() { if type_vars.len() != params_ast.len() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "expect {} type parameters but got {} (at {})",
"expect {} type parameters but got {} (at {})", type_vars.len(),
type_vars.len(), params_ast.len(),
params_ast.len(), params_ast[0].location,
params_ast[0].location, )]));
),
]))
} }
let result = params_ast let result = params_ast
.iter() .iter()
@ -210,7 +208,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
"application of type vars to generic class is not currently supported (at {})", "application of type vars to generic class is not currently supported (at {})",
params_ast[0].location params_ast[0].location
), ),
])) ]));
} }
}; };
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos }) Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
@ -309,9 +307,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
// Literal // Literal
ast::ExprKind::Subscript { value, slice, .. } ast::ExprKind::Subscript { value, slice, .. }
if { if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into()) matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
} => { } =>
{
let tup_elts = { let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node { if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice() elts.as_slice()
@ -321,20 +320,18 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
}; };
let type_annotations = tup_elts let type_annotations = tup_elts
.iter() .iter()
.map(|e| { .map(|e| match &e.node {
match &e.node { ast::ExprKind::Constant { value, .. } => {
ast::ExprKind::Constant { value, .. } => Ok( Ok(TypeAnnotation::Literal(vec![value.clone()]))
TypeAnnotation::Literal(vec![value.clone()]),
),
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
),
} }
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
),
}) })
.collect::<Result<Vec<_>, _>>()? .collect::<Result<Vec<_>, _>>()?
.into_iter() .into_iter()
@ -347,9 +344,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if type_annotations.len() == 1 { if type_annotations.len() == 1 {
Ok(TypeAnnotation::Literal(type_annotations)) Ok(TypeAnnotation::Literal(type_annotations))
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!(
format!("multiple literal bounds are currently unsupported (at {})", value.location) "multiple literal bounds are currently unsupported (at {})",
])) value.location
)]))
} }
} }
@ -358,19 +356,19 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if let ast::ExprKind::Name { id, .. } = &value.node { if let ast::ExprKind::Name { id, .. } = &value.node {
class_name_handle(id, slice, unifier, locked) class_name_handle(id, slice, unifier, locked)
} else { } else {
Err(HashSet::from([ Err(HashSet::from([format!(
format!("unsupported expression type for class name (at {})", value.location) "unsupported expression type for class name (at {})",
])) value.location
)]))
} }
} }
ast::ExprKind::Constant { value, .. } => { ast::ExprKind::Constant { value, .. } => Ok(TypeAnnotation::Literal(vec![value.clone()])),
Ok(TypeAnnotation::Literal(vec![value.clone()]))
}
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!("unsupported expression for type annotation (at {})", expr.location), "unsupported expression for type annotation (at {})",
])), expr.location
)])),
} }
} }
@ -381,7 +379,7 @@ pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>], top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier, unifier: &mut Unifier,
ann: &TypeAnnotation, ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>> subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> { ) -> Result<Type, HashSet<String>> {
match ann { match ann {
TypeAnnotation::CustomClass { id: obj_id, params } => { TypeAnnotation::CustomClass { id: obj_id, params } => {
@ -392,24 +390,17 @@ pub fn get_type_from_type_annotation_kinds(
}; };
if type_vars.len() != params.len() { if type_vars.len() != params.len() {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "unexpected number of type parameters: expected {} but got {}",
"unexpected number of type parameters: expected {} but got {}", type_vars.len(),
type_vars.len(), params.len()
params.len() )]));
),
]))
} }
let param_ty = params let param_ty = params
.iter() .iter()
.map(|x| { .map(|x| {
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
top_level_defs,
unifier,
x,
subst_list
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
@ -419,7 +410,14 @@ pub fn get_type_from_type_annotation_kinds(
let mut result = VarMap::new(); let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) { for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() { match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => { TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = { let ok: bool = {
// create a temp type var and unify to check compatibility // create a temp type var and unify to check compatibility
p == *tvar || { p == *tvar || {
@ -434,18 +432,16 @@ pub fn get_type_from_type_annotation_kinds(
if ok { if ok {
result.insert(*id, p); result.insert(*id, p);
} else { } else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "cannot apply type {} to type variable with id {:?}",
"cannot apply type {} to type variable with id {:?}", unifier.internal_stringify(
unifier.internal_stringify( p,
p, &mut |id| format!("class{id}"),
&mut |id| format!("class{id}"), &mut |id| format!("typevar{id}"),
&mut |id| format!("typevar{id}"), &mut None
&mut None ),
), *id
*id )]));
)
]))
} }
} }
@ -454,24 +450,18 @@ pub fn get_type_from_type_annotation_kinds(
let ok: bool = { let ok: bool = {
// create a temp type var and unify to check compatibility // create a temp type var and unify to check compatibility
p == *tvar || { p == *tvar || {
let temp = unifier.get_fresh_const_generic_var( let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
ty,
*name,
*loc,
);
unifier.unify(temp.0, p).is_ok() unifier.unify(temp.0, p).is_ok()
} }
}; };
if ok { if ok {
result.insert(*id, p); result.insert(*id, p);
} else { } else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "cannot apply type {} to type variable {}",
"cannot apply type {} to type variable {}", unifier.stringify(p),
unifier.stringify(p), name.unwrap_or_else(|| format!("typevar{id}").into()),
name.unwrap_or_else(|| format!("typevar{id}").into()), )]));
),
]))
} }
} }
@ -507,7 +497,8 @@ pub fn get_type_from_type_annotation_kinds(
} }
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty), TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Literal(values) => { TypeAnnotation::Literal(values) => {
let values = values.iter() let values = values
.iter()
.map(SymbolValue::from_constant_inferred) .map(SymbolValue::from_constant_inferred)
.collect::<Result<Vec<_>, _>>() .collect::<Result<Vec<_>, _>>()
.map_err(|err| HashSet::from([err]))?; .map_err(|err| HashSet::from([err]))?;
@ -520,7 +511,7 @@ pub fn get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
unifier, unifier,
ty.as_ref(), ty.as_ref(),
subst_list subst_list,
)?; )?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty })) Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
} }
@ -529,7 +520,7 @@ pub fn get_type_from_type_annotation_kinds(
top_level_defs, top_level_defs,
unifier, unifier,
ty.as_ref(), ty.as_ref(),
subst_list subst_list,
)?; )?;
Ok(unifier.add_ty(TypeEnum::TList { ty })) Ok(unifier.add_ty(TypeEnum::TList { ty }))
} }
@ -607,7 +598,8 @@ pub fn check_overload_type_annotation_compatible(
let ( let (
TypeEnum::TVar { id: a, fields: None, .. }, TypeEnum::TVar { id: a, fields: None, .. },
TypeEnum::TVar { id: b, fields: None, .. }, TypeEnum::TVar { id: b, fields: None, .. },
) = (a, b) else { ) = (a, b)
else {
unreachable!("must be type var") unreachable!("must be type var")
}; };

View File

@ -2,15 +2,17 @@ use crate::typecheck::typedef::TypeEnum;
use super::type_inferencer::Inferencer; use super::type_inferencer::Inferencer;
use super::typedef::Type; use super::typedef::Type;
use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef}; use nac3parser::ast::{
self, Constant, Expr, ExprKind,
Operator::{LShift, RShift},
Stmt, StmtKind, StrRef,
};
use std::{collections::HashSet, iter::once}; use std::{collections::HashSet, iter::once};
impl<'a> Inferencer<'a> { impl<'a> Inferencer<'a> {
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> { fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) { if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) {
Err(HashSet::from([ Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)]))
format!("Error at {}: cannot have value none", expr.location),
]))
} else { } else {
Ok(()) Ok(())
} }
@ -22,9 +24,9 @@ impl<'a> Inferencer<'a> {
defined_identifiers: &mut HashSet<StrRef>, defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> { ) -> Result<(), HashSet<String>> {
match &pattern.node { match &pattern.node {
ExprKind::Name { id, .. } if id == &"none".into() => Err(HashSet::from([ ExprKind::Name { id, .. } if id == &"none".into() => {
format!("cannot assign to a `none` (at {})", pattern.location), Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
])), }
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) { if !defined_identifiers.contains(id) {
defined_identifiers.insert(*id); defined_identifiers.insert(*id);
@ -44,20 +46,17 @@ impl<'a> Inferencer<'a> {
self.should_have_value(value)?; self.should_have_value(value)?;
self.check_expr(slice, defined_identifiers)?; self.check_expr(slice, defined_identifiers)?;
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) { if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "Error at {}: cannot assign to tuple element",
"Error at {}: cannot assign to tuple element", value.location
value.location )]));
),
]))
} }
Ok(()) Ok(())
} }
ExprKind::Constant { .. } => { ExprKind::Constant { .. } => Err(HashSet::from([format!(
Err(HashSet::from([ "cannot assign to a constant (at {})",
format!("cannot assign to a constant (at {})", pattern.location), pattern.location
])) )])),
}
_ => self.check_expr(pattern, defined_identifiers), _ => self.check_expr(pattern, defined_identifiers),
} }
} }
@ -69,14 +68,14 @@ impl<'a> Inferencer<'a> {
) -> Result<(), HashSet<String>> { ) -> Result<(), HashSet<String>> {
// there are some cases where the custom field is None // there are some cases where the custom field is None
if let Some(ty) = &expr.custom { if let Some(ty) = &expr.custom {
if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) { if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. })
return Err(HashSet::from([ && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables)
format!( {
"expected concrete type at {} but got {}", return Err(HashSet::from([format!(
expr.location, "expected concrete type at {} but got {}",
self.unifier.get_ty(*ty).get_type_name() expr.location,
) self.unifier.get_ty(*ty).get_type_name()
])) )]));
} }
} }
match &expr.node { match &expr.node {
@ -96,12 +95,10 @@ impl<'a> Inferencer<'a> {
self.defined_identifiers.insert(*id); self.defined_identifiers.insert(*id);
} }
Err(e) => { Err(e) => {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "type error at identifier `{}` ({}) at {}",
"type error at identifier `{}` ({}) at {}", id, e, expr.location
id, e, expr.location )]))
)
]))
} }
} }
} }
@ -127,17 +124,13 @@ impl<'a> Inferencer<'a> {
// Check whether a bitwise shift has a negative RHS constant value // Check whether a bitwise shift has a negative RHS constant value
if *op == LShift || *op == RShift { if *op == LShift || *op == RShift {
if let ExprKind::Constant { value, .. } = &right.node { if let ExprKind::Constant { value, .. } = &right.node {
let Constant::Int(rhs_val) = value else { let Constant::Int(rhs_val) = value else { unreachable!() };
unreachable!()
};
if *rhs_val < 0 { if *rhs_val < 0 {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "shift count is negative at {}",
"shift count is negative at {}", right.location
right.location )]));
),
]))
} }
} }
} }
@ -214,16 +207,16 @@ impl<'a> Inferencer<'a> {
/// is freed when the function returns. /// is freed when the function returns.
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool { fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
match &*self.unifier.get_ty_immutable(ret_ty) { match &*self.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { .. } => { TypeEnum::TObj { .. } => [
[ self.primitives.int32,
self.primitives.int32, self.primitives.int64,
self.primitives.int64, self.primitives.uint32,
self.primitives.uint32, self.primitives.uint64,
self.primitives.uint64, self.primitives.float,
self.primitives.float, self.primitives.bool,
self.primitives.bool, ]
].iter().any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)) .iter()
} .any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)), TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false, _ => false,
} }
@ -330,8 +323,11 @@ impl<'a> Inferencer<'a> {
if let Some(ret_ty) = value.custom { if let Some(ret_ty) = value.custom {
// Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually // Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually
// inferred and just generates an unconditional assertion // inferred and just generates an unconditional assertion
if matches!(value.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) { if matches!(
return Ok(true) value.node,
ExprKind::Constant { value: Constant::Ellipsis, .. }
) {
return Ok(true);
} }
if !self.check_return_value_ty(ret_ty) { if !self.check_return_value_ty(ret_ty) {
@ -341,7 +337,7 @@ impl<'a> Inferencer<'a> {
self.unifier.stringify(ret_ty), self.unifier.stringify(ret_ty),
value.location, value.location,
), ),
])) ]));
} }
} }
} }

View File

@ -1,19 +1,20 @@
use std::cmp::max;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS; use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys}; use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{ use crate::typecheck::{
type_inferencer::*, type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
}; };
use itertools::Itertools;
use nac3parser::ast::StrRef; use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop}; use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::cmp::max;
use std::collections::HashMap; use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use itertools::Itertools; use strum::IntoEnumIterator;
#[must_use] #[must_use]
pub fn binop_name(op: &Operator) -> &'static str { pub fn binop_name(op: Operator) -> &'static str {
match op { match op {
Operator::Add => "__add__", Operator::Add => "__add__",
Operator::Sub => "__sub__", Operator::Sub => "__sub__",
@ -32,7 +33,7 @@ pub fn binop_name(op: &Operator) -> &'static str {
} }
#[must_use] #[must_use]
pub fn binop_assign_name(op: &Operator) -> &'static str { pub fn binop_assign_name(op: Operator) -> &'static str {
match op { match op {
Operator::Add => "__iadd__", Operator::Add => "__iadd__",
Operator::Sub => "__isub__", Operator::Sub => "__isub__",
@ -51,7 +52,7 @@ pub fn binop_assign_name(op: &Operator) -> &'static str {
} }
#[must_use] #[must_use]
pub fn unaryop_name(op: &Unaryop) -> &'static str { pub fn unaryop_name(op: Unaryop) -> &'static str {
match op { match op {
Unaryop::UAdd => "__pos__", Unaryop::UAdd => "__pos__",
Unaryop::USub => "__neg__", Unaryop::USub => "__neg__",
@ -61,7 +62,7 @@ pub fn unaryop_name(op: &Unaryop) -> &'static str {
} }
#[must_use] #[must_use]
pub fn comparison_name(op: &Cmpop) -> Option<&'static str> { pub fn comparison_name(op: Cmpop) -> Option<&'static str> {
match op { match op {
Cmpop::Lt => Some("__lt__"), Cmpop::Lt => Some("__lt__"),
Cmpop::LtE => Some("__le__"), Cmpop::LtE => Some("__le__"),
@ -115,7 +116,7 @@ pub fn impl_binop(
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0); let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops { for op in ops {
fields.insert(binop_name(op).into(), { fields.insert(binop_name(*op).into(), {
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -130,7 +131,7 @@ pub fn impl_binop(
) )
}); });
fields.insert(binop_assign_name(op).into(), { fields.insert(binop_assign_name(*op).into(), {
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -154,7 +155,7 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
for op in ops { for op in ops {
fields.insert( fields.insert(
unaryop_name(op).into(), unaryop_name(*op).into(),
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -194,7 +195,7 @@ pub fn impl_cmpop(
for op in ops { for op in ops {
fields.insert( fields.insert(
comparison_name(op).unwrap().into(), comparison_name(*op).unwrap().into(),
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: ret_ty, ret: ret_ty,
@ -255,7 +256,14 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
/// `LShift`, `RShift` /// `LShift`, `RShift`
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[store.int32, store.uint32], Some(ty), &[Operator::LShift, Operator::RShift]); impl_binop(
unifier,
store,
ty,
&[store.int32, store.uint32],
Some(ty),
&[Operator::LShift, Operator::RShift],
);
} }
/// `Div` /// `Div`
@ -297,7 +305,7 @@ pub fn impl_matmul(
store: &PrimitiveStore, store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Option<Type>, ret_ty: Option<Type>,
) { ) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]); impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]);
} }
@ -353,8 +361,8 @@ pub fn typeof_ndarray_broadcast(
left: Type, left: Type,
right: Type, right: Type,
) -> Result<Type, String> { ) -> Result<Type, String> {
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
assert!(is_left_ndarray || is_right_ndarray); assert!(is_left_ndarray || is_right_ndarray);
@ -375,7 +383,8 @@ pub fn typeof_ndarray_broadcast(
_ => unreachable!(), _ => unreachable!(),
}; };
let res_ndims = left_ty_ndims.into_iter() let res_ndims = left_ty_ndims
.into_iter()
.cartesian_product(right_ty_ndims) .cartesian_product(right_ty_ndims)
.map(|(left, right)| { .map(|(left, right)| {
let left_val = u64::try_from(left).unwrap(); let left_val = u64::try_from(left).unwrap();
@ -390,11 +399,7 @@ pub fn typeof_ndarray_broadcast(
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims))) Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
} else { } else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray { let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
(left, right)
} else {
(right, left)
};
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty); let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
@ -420,25 +425,21 @@ pub fn typeof_ndarray_broadcast(
pub fn typeof_binop( pub fn typeof_binop(
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
op: &Operator, op: Operator,
lhs: Type, lhs: Type,
rhs: Type, rhs: Type,
) -> Result<Option<Type>, String> { ) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(match op { Ok(Some(match op {
Operator::Add Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
| Operator::Sub
| Operator::Mult
| Operator::Mod
| Operator::FloorDiv => {
if is_left_ndarray || is_right_ndarray { if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)? typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) { } else if unifier.unioned(lhs, rhs) {
lhs lhs
} else { } else {
return Ok(None) return Ok(None);
} }
} }
@ -464,12 +465,14 @@ pub fn typeof_binop(
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?, (2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
(lhs, rhs) if lhs == 0 || rhs == 0 => { (lhs, rhs) if lhs == 0 || rhs == 0 => {
return Err(format!( return Err(format!(
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})", "Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
(rhs == 0) as u8 u8::from(rhs == 0)
)) ))
} }
(lhs, rhs) => { (lhs, rhs) => {
return Err(format!("ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported")) return Err(format!(
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
))
} }
} }
} }
@ -480,29 +483,35 @@ pub fn typeof_binop(
} else if unifier.unioned(lhs, rhs) { } else if unifier.unioned(lhs, rhs) {
primitives.float primitives.float
} else { } else {
return Ok(None) return Ok(None);
} }
} }
Operator::Pow => { Operator::Pow => {
if is_left_ndarray || is_right_ndarray { if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)? typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) { } else if [
primitives.int32,
primitives.int64,
primitives.uint32,
primitives.uint64,
primitives.float,
]
.into_iter()
.any(|ty| unifier.unioned(lhs, ty))
{
lhs lhs
} else { } else {
return Ok(None) return Ok(None);
} }
} }
Operator::LShift Operator::LShift | Operator::RShift => lhs,
| Operator::RShift => lhs, Operator::BitOr | Operator::BitXor | Operator::BitAnd => {
Operator::BitOr
| Operator::BitXor
| Operator::BitAnd => {
if unifier.unioned(lhs, rhs) { if unifier.unioned(lhs, rhs) {
lhs lhs
} else { } else {
return Ok(None) return Ok(None);
} }
} }
})) }))
@ -511,50 +520,51 @@ pub fn typeof_binop(
pub fn typeof_unaryop( pub fn typeof_unaryop(
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
op: &Unaryop, op: Unaryop,
operand: Type, operand: Type,
) -> Result<Option<Type>, String> { ) -> Result<Option<Type>, String> {
let operand_obj_id = operand.obj_id(unifier); let operand_obj_id = operand.obj_id(unifier);
if *op == Unaryop::Not && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) { if op == Unaryop::Not
return Err("The truth value of an array with more than one element is ambiguous".to_string()) && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap())
{
return Err(
"The truth value of an array with more than one element is ambiguous".to_string()
);
} }
Ok(match *op { Ok(match op {
Unaryop::Not => { Unaryop::Not => match operand_obj_id {
match operand_obj_id { Some(v) if v == PrimDef::NDArray.id() => Some(operand),
Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand), Some(_) => Some(primitives.bool),
Some(_) => Some(primitives.bool), _ => None,
_ => None },
}
}
Unaryop::Invert => { Unaryop::Invert => {
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32) Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { } else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand) Some(operand)
} else { } else {
None None
} }
} }
Unaryop::UAdd Unaryop::UAdd | Unaryop::USub => {
| Unaryop::USub => { if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand); let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if *op == Unaryop::UAdd { return Err(if op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string() "The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
} else { } else {
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string() "The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
}) });
} }
Some(operand) Some(operand)
} else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) { } else if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32) Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { } else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand) Some(operand)
} else { } else {
None None
@ -567,16 +577,12 @@ pub fn typeof_unaryop(
pub fn typeof_cmpop( pub fn typeof_cmpop(
unifier: &mut Unifier, unifier: &mut Unifier,
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
_op: &Cmpop, _op: Cmpop,
lhs: Type, lhs: Type,
rhs: Type, rhs: Type,
) -> Result<Option<Type>, String> { ) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
.obj_id(unifier) let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
Ok(Some(if is_left_ndarray || is_right_ndarray { Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?; let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
@ -586,7 +592,7 @@ pub fn typeof_cmpop(
} else if unifier.unioned(lhs, rhs) { } else if unifier.unioned(lhs, rhs) {
primitives.bool primitives.bool
} else { } else {
return Ok(None) return Ok(None);
})) }))
} }
@ -643,11 +649,19 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* ndarray ===== */ /* ndarray ===== */
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); let ndarray_usized_ndims_tvar =
let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0)); unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t =
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_basic_arithmetic(
unifier,
store,
ndarray_t,
&[ndarray_unsized_t, ndarray_unsized_dtype_t],
None,
);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);

View File

@ -89,10 +89,7 @@ impl<'a> Display for DisplayTypeError<'a> {
IncorrectArgType { name, expected, got } => { IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes); let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes); let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!( write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}")
f,
"Incorrect argument type for {name}. Expected {expected}, but got {got}"
)
} }
FieldUnificationError { field, types, loc } => { FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes); let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);

File diff suppressed because it is too large Load Diff

View File

@ -3,12 +3,12 @@ use super::*;
use crate::{ use crate::{
codegen::CodeGenContext, codegen::CodeGenContext,
symbol_resolver::ValueEnum, symbol_resolver::ValueEnum,
toplevel::{DefinitionId, helper::PRIMITIVE_DEF_IDS, TopLevelDef}, toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
}; };
use indoc::indoc; use indoc::indoc;
use std::iter::zip;
use nac3parser::parser::parse_program; use nac3parser::parser::parse_program;
use parking_lot::RwLock; use parking_lot::RwLock;
use std::iter::zip;
use test_case::test_case; use test_case::test_case;
struct Resolver { struct Resolver {
@ -44,7 +44,9 @@ impl SymbolResolver for Resolver {
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.id_to_def.get(&id).cloned() self.id_to_def
.get(&id)
.cloned()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()])) .ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
} }
@ -73,7 +75,7 @@ impl TestEnvironment {
let mut unifier = Unifier::new(); let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int32, obj_id: PrimDef::Int32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
@ -86,59 +88,60 @@ impl TestEnvironment {
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int64, obj_id: PrimDef::Int64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.float, obj_id: PrimDef::Float.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.bool, obj_id: PrimDef::Bool.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.none, obj_id: PrimDef::None.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.range, obj_id: PrimDef::Range.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.str, obj_id: PrimDef::Str.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.exception, obj_id: PrimDef::Exception.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let uint32 = unifier.add_ty(TypeEnum::TObj { let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint32, obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let uint64 = unifier.add_ty(TypeEnum::TObj { let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint64, obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.option, obj_id: PrimDef::Option.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray, obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::from([ params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
@ -211,7 +214,7 @@ impl TestEnvironment {
let mut identifier_mapping = HashMap::new(); let mut identifier_mapping = HashMap::new();
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new(); let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj { let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int32, obj_id: PrimDef::Int32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
@ -224,57 +227,57 @@ impl TestEnvironment {
fields.insert("__add__".into(), (add_ty, false)); fields.insert("__add__".into(), (add_ty, false));
}); });
let int64 = unifier.add_ty(TypeEnum::TObj { let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int64, obj_id: PrimDef::Int64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let float = unifier.add_ty(TypeEnum::TObj { let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.float, obj_id: PrimDef::Float.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let bool = unifier.add_ty(TypeEnum::TObj { let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.bool, obj_id: PrimDef::Bool.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let none = unifier.add_ty(TypeEnum::TObj { let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.none, obj_id: PrimDef::None.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let range = unifier.add_ty(TypeEnum::TObj { let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.range, obj_id: PrimDef::Range.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let str = unifier.add_ty(TypeEnum::TObj { let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.str, obj_id: PrimDef::Str.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let exception = unifier.add_ty(TypeEnum::TObj { let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.exception, obj_id: PrimDef::Exception.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let uint32 = unifier.add_ty(TypeEnum::TObj { let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint32, obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let uint64 = unifier.add_ty(TypeEnum::TObj { let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint64, obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let option = unifier.add_ty(TypeEnum::TObj { let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.option, obj_id: PrimDef::Option.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray, obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });

View File

@ -1,12 +1,12 @@
use indexmap::IndexMap;
use itertools::Itertools;
use std::cell::RefCell; use std::cell::RefCell;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::Display; use std::fmt::Display;
use std::iter::zip;
use std::rc::Rc; use std::rc::Rc;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet}; use std::{borrow::Cow, collections::HashSet};
use std::iter::zip;
use indexmap::IndexMap;
use itertools::Itertools;
use nac3parser::ast::{Location, StrRef}; use nac3parser::ast::{Location, StrRef};
@ -61,7 +61,7 @@ pub enum RecordKey {
} }
impl Type { impl Type {
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching /// Wrapper function for cleaner code so that we don't need to write this long pattern matching
/// just to get the field `obj_id`. /// just to get the field `obj_id`.
#[must_use] #[must_use]
pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> { pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> {
@ -250,9 +250,9 @@ impl Unifier {
} }
/// Returns the [`UnificationTable`] associated with this `Unifier`. /// Returns the [`UnificationTable`] associated with this `Unifier`.
/// ///
/// # Safety /// # Safety
/// ///
/// The use of this function is discouraged under most circumstances. Only use this function if /// The use of this function is discouraged under most circumstances. Only use this function if
/// in-place manipulation of type variables and/or type fields is necessary, otherwise prefer to /// in-place manipulation of type variables and/or type fields is necessary, otherwise prefer to
/// [add a new type][`Unifier::add_ty`] and [unify the type][`Unifier::unify`] with an existing /// [add a new type][`Unifier::add_ty`] and [unify the type][`Unifier::unify`] with an existing
@ -379,7 +379,17 @@ impl Unifier {
let id = self.var_id + 1; let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
let range = range.to_vec(); let range = range.to_vec();
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id) (
self.add_ty(TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
}),
id,
)
} }
/// Returns a fresh type representing a constant generic variable with the given underlying type `ty`. /// Returns a fresh type representing a constant generic variable with the given underlying type `ty`.
@ -391,19 +401,22 @@ impl Unifier {
) -> (Type, u32) { ) -> (Type, u32) {
let id = self.var_id + 1; let id = self.var_id + 1;
self.var_id += 1; self.var_id += 1;
(self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id) (
self.add_ty(TypeEnum::TVar {
id,
range: vec![ty],
fields: None,
name,
loc,
is_const_generic: true,
}),
id,
)
} }
/// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`. /// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`.
pub fn get_fresh_literal( pub fn get_fresh_literal(&mut self, values: Vec<SymbolValue>, loc: Option<Location>) -> Type {
&mut self, let ty_enum = TypeEnum::TLiteral { values: values.into_iter().dedup().collect(), loc };
values: Vec<SymbolValue>,
loc: Option<Location>,
) -> Type {
let ty_enum = TypeEnum::TLiteral {
values: values.into_iter().dedup().collect(),
loc
};
self.add_ty(ty_enum) self.add_ty(ty_enum)
} }
@ -423,7 +436,9 @@ impl Unifier {
Some( Some(
range range
.iter() .iter()
.flat_map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])) .flat_map(|ty| {
self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])
})
.collect_vec(), .collect_vec(),
) )
} }
@ -479,7 +494,7 @@ impl Unifier {
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool { pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*; use TypeEnum::*;
match &*self.get_ty(a) { match &*self.get_ty(a) {
TRigidVar { .. } TRigidVar { .. }
| TLiteral { .. } | TLiteral { .. }
// functions are instantiated for each call sites, so the function type can contain // functions are instantiated for each call sites, so the function type can contain
// type variables. // type variables.
@ -487,7 +502,7 @@ impl Unifier {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false, TCall { .. } => false,
TList { ty } TList { ty }
| TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
@ -526,9 +541,7 @@ impl Unifier {
let instantiated = self.instantiate_fun(b, signature); let instantiated = self.instantiate_fun(b, signature);
let r = self.get_ty(instantiated); let r = self.get_ty(instantiated);
let r = r.as_ref(); let r = r.as_ref();
let TypeEnum::TFunc(signature) = r else { let TypeEnum::TFunc(signature) = r else { unreachable!() };
unreachable!()
};
// we check to make sure that all required arguments (those without default // we check to make sure that all required arguments (those without default
// arguments) are provided, and do not provide the same argument twice. // arguments) are provided, and do not provide the same argument twice.
let mut required = required.to_vec(); let mut required = required.to_vec();
@ -555,13 +568,10 @@ impl Unifier {
if let Some(i) = required.iter().position(|v| v == k) { if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i); required.remove(i);
} }
let i = all_names let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| {
.iter() self.restore_snapshot();
.position(|v| &v.0 == k) TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
.ok_or_else(|| { })?;
self.restore_snapshot();
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
})?;
let (name, expected) = all_names.remove(i); let (name, expected) = all_names.remove(i);
self.unify_impl(expected, *t, false).map_err(|_| { self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot(); self.restore_snapshot();
@ -627,8 +637,17 @@ impl Unifier {
}; };
match (&*ty_a, &*ty_b) { match (&*ty_a, &*ty_b) {
( (
TVar { fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, .. }, TVar {
TVar { fields: fields2, id: id2, name: name2, loc: loc2, is_const_generic: false, .. }, fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, ..
},
TVar {
fields: fields2,
id: id2,
name: name2,
loc: loc2,
is_const_generic: false,
..
},
) => { ) => {
let new_fields = match (fields1, fields2) { let new_fields = match (fields1, fields2) {
(None, None) => None, (None, None) => None,
@ -750,7 +769,10 @@ impl Unifier {
self.set_a_to_b(a, x); self.set_a_to_b(a, x);
} }
(TVar { id: id1, range: ty1, is_const_generic: true, .. }, TVar { id: id2, range: ty2, .. }) => { (
TVar { id: id1, range: ty1, is_const_generic: true, .. },
TVar { id: id2, range: ty2, .. },
) => {
let ty1 = ty1[0]; let ty1 = ty1[0];
let ty2 = ty2[0]; let ty2 = ty2[0];
@ -765,17 +787,17 @@ impl Unifier {
assert_eq!(tys.len(), 1); assert_eq!(tys.len(), 1);
assert_eq!(values.len(), 1); assert_eq!(values.len(), 1);
let primitives = &self.primitive_store let primitives =
.expect("Expected PrimitiveStore to be present"); &self.primitive_store.expect("Expected PrimitiveStore to be present");
let ty = tys[0]; let ty = tys[0];
let value= &values[0]; let value = &values[0];
let value_ty = value.get_type(primitives, self); let value_ty = value.get_type(primitives, self);
// If the types don't match, try to implicitly promote integers // If the types don't match, try to implicitly promote integers
if !self.unioned(ty, value_ty) { if !self.unioned(ty, value_ty) {
let Ok(num_val) = i128::try_from(value.clone()) else { let Ok(num_val) = i128::try_from(value.clone()) else {
return Self::incompatible_types(a, b) return Self::incompatible_types(a, b);
}; };
let can_convert = if self.unioned(ty, primitives.int32) { let can_convert = if self.unioned(ty, primitives.int32) {
@ -791,7 +813,7 @@ impl Unifier {
}; };
if !can_convert { if !can_convert {
return Self::incompatible_types(a, b) return Self::incompatible_types(a, b);
} }
} }
@ -801,22 +823,12 @@ impl Unifier {
(TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => { (TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => {
for (v1, v2) in zip(val1, val2) { for (v1, v2) in zip(val1, val2) {
if v1 != v2 { if v1 != v2 {
let symbol_value_to_int = |value: &SymbolValue| -> Option<i128> {
match value {
SymbolValue::I32(v) => Some(*v as i128),
SymbolValue::I64(v) => Some(*v as i128),
SymbolValue::U32(v) => Some(*v as i128),
SymbolValue::U64(v) => Some(*v as i128),
_ => None,
}
};
// Try performing integer promotion on literals // Try performing integer promotion on literals
let v1i = symbol_value_to_int(v1); let v1i = i128::try_from(v1.clone()).ok();
let v2i = symbol_value_to_int(v2); let v2i = i128::try_from(v2.clone()).ok();
if v1i != v2i { if v1i != v2i {
return Self::incompatible_types(a, b) return Self::incompatible_types(a, b);
} }
} }
} }
@ -1287,8 +1299,8 @@ impl Unifier {
mapping: &VarMap, mapping: &VarMap,
cache: &mut HashMap<Type, Option<Type>>, cache: &mut HashMap<Type, Option<Type>>,
) -> Option<IndexMapping<K>> ) -> Option<IndexMapping<K>>
where where
K: std::hash::Hash + Eq + Clone, K: std::hash::Hash + Eq + Clone,
{ {
let mut map2 = None; let mut map2 = None;
for (k, v) in map { for (k, v) in map {

View File

@ -45,9 +45,9 @@ impl Unifier {
} }
} }
fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
where where
K: std::hash::Hash + Eq + Clone K: std::hash::Hash + Eq + Clone,
{ {
if map1.len() != map2.len() { if map1.len() != map2.len() {
return false; return false;
@ -342,16 +342,12 @@ fn test_recursive_subst() {
with_fields(&mut env.unifier, foo_id, |_unifier, fields| { with_fields(&mut env.unifier, foo_id, |_unifier, fields| {
fields.insert("rec".into(), (foo_id, true)); fields.insert("rec".into(), (foo_id, true));
}); });
let TypeEnum::TObj { params, .. } = &*foo_ty else { let TypeEnum::TObj { params, .. } = &*foo_ty else { unreachable!() };
unreachable!()
};
let mapping = params.iter().map(|(id, _)| (*id, int)).collect(); let mapping = params.iter().map(|(id, _)| (*id, int)).collect();
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap(); let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
let instantiated_ty = env.unifier.get_ty(instantiated); let instantiated_ty = env.unifier.get_ty(instantiated);
let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { unreachable!() };
unreachable!()
};
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int)); assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated)); assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
} }
@ -477,7 +473,8 @@ fn test_typevar_range() {
assert_eq!( assert_eq!(
env.unify(a_list, int_list), env.unify(a_list, int_list),
Err("Incompatible types: list[typevar22] and list[0]\ Err("Incompatible types: list[typevar22] and list[0]\
\n\nNotes:\n typevar22 {1}".into()) \n\nNotes:\n typevar22 {1}"
.into())
); );
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0; let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
@ -505,7 +502,10 @@ fn test_rigid_var() {
assert_eq!(env.unify(a, b), Err("Incompatible types: typevar3 and typevar2".to_string())); assert_eq!(env.unify(a, b), Err("Incompatible types: typevar3 and typevar2".to_string()));
env.unifier.unify(list_a, list_x).unwrap(); env.unifier.unify(list_a, list_x).unwrap();
assert_eq!(env.unify(list_x, list_int), Err("Incompatible types: list[typevar2] and list[0]".to_string())); assert_eq!(
env.unify(list_x, list_int),
Err("Incompatible types: list[typevar2] and list[0]".to_string())
);
env.unifier.replace_rigid_var(a, int); env.unifier.replace_rigid_var(a, int);
env.unifier.unify(list_x, list_int).unwrap(); env.unifier.unify(list_x, list_int).unwrap();

View File

@ -16,21 +16,10 @@ pub struct UnificationTable<V> {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum Action<V> { enum Action<V> {
Parent { Parent { key: usize, original_parent: usize },
key: usize, Value { key: usize, original_value: Option<V> },
original_parent: usize, Rank { key: usize, original_rank: u32 },
}, Marker { generation: u32 },
Value {
key: usize,
original_value: Option<V>,
},
Rank {
key: usize,
original_rank: u32,
},
Marker {
generation: u32,
}
} }
impl<V> Default for UnificationTable<V> { impl<V> Default for UnificationTable<V> {
@ -41,7 +30,13 @@ impl<V> Default for UnificationTable<V> {
impl<V> UnificationTable<V> { impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> { pub fn new() -> UnificationTable<V> {
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 } UnificationTable {
parents: Vec::new(),
ranks: Vec::new(),
values: Vec::new(),
log: Vec::new(),
generation: 0,
}
} }
pub fn new_key(&mut self, v: V) -> UnificationKey { pub fn new_key(&mut self, v: V) -> UnificationKey {
@ -125,7 +120,10 @@ impl<V> UnificationTable<V> {
pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) { pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot; let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot restoration error"); assert!(self.log.len() >= log_len, "snapshot restoration error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error"); assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot restoration error"
);
for action in self.log.drain(log_len - 1..).rev() { for action in self.log.drain(log_len - 1..).rev() {
match action { match action {
Action::Parent { key, original_parent } => { Action::Parent { key, original_parent } => {
@ -145,7 +143,10 @@ impl<V> UnificationTable<V> {
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) { pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot; let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot discard error"); assert!(self.log.len() >= log_len, "snapshot discard error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot discard error"); assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot discard error"
);
self.log.clear(); self.log.clear();
} }
} }
@ -159,11 +160,23 @@ where
.enumerate() .enumerate()
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None }) .map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
.collect(); .collect();
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 } UnificationTable {
parents: self.parents.clone(),
ranks: self.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
} }
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> { pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect(); let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 } UnificationTable {
parents: table.parents.clone(),
ranks: table.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
} }
} }

View File

@ -32,7 +32,6 @@ pub struct DwarfReader<'a> {
} }
impl<'a> DwarfReader<'a> { impl<'a> DwarfReader<'a> {
pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader { pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader {
DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr } DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr }
} }
@ -60,7 +59,7 @@ impl<'a> DwarfReader<'a> {
let mut byte: u8; let mut byte: u8;
loop { loop {
byte = self.read_u8(); byte = self.read_u8();
result |= ((byte & 0x7F) as u64) << shift; result |= u64::from(byte & 0x7F) << shift;
shift += 7; shift += 7;
if byte & 0x80 == 0 { if byte & 0x80 == 0 {
break; break;
@ -75,7 +74,7 @@ impl<'a> DwarfReader<'a> {
let mut byte: u8; let mut byte: u8;
loop { loop {
byte = self.read_u8(); byte = self.read_u8();
result |= ((byte & 0x7F) as u64) << shift; result |= u64::from(byte & 0x7F) << shift;
shift += 7; shift += 7;
if byte & 0x80 == 0 { if byte & 0x80 == 0 {
break; break;
@ -157,10 +156,9 @@ fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize,
} }
match encoding & 0x0F { match encoding & 0x0F {
DW_EH_PE_absptr => Ok(reader.read_u32() as usize), DW_EH_PE_absptr | DW_EH_PE_udata4 => Ok(reader.read_u32() as usize),
DW_EH_PE_uleb128 => Ok(reader.read_uleb128() as usize), DW_EH_PE_uleb128 => Ok(reader.read_uleb128() as usize),
DW_EH_PE_udata2 => Ok(reader.read_u16() as usize), DW_EH_PE_udata2 => Ok(reader.read_u16() as usize),
DW_EH_PE_udata4 => Ok(reader.read_u32() as usize),
DW_EH_PE_udata8 => Ok(reader.read_u64() as usize), DW_EH_PE_udata8 => Ok(reader.read_u64() as usize),
DW_EH_PE_sleb128 => Ok(reader.read_sleb128() as usize), DW_EH_PE_sleb128 => Ok(reader.read_sleb128() as usize),
DW_EH_PE_sdata2 => Ok(reader.read_i16() as usize), DW_EH_PE_sdata2 => Ok(reader.read_i16() as usize),
@ -170,10 +168,7 @@ fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize,
} }
} }
fn read_encoded_pointer_with_pc( fn read_encoded_pointer_with_pc(reader: &mut DwarfReader, encoding: u8) -> Result<usize, ()> {
reader: &mut DwarfReader,
encoding: u8,
) -> Result<usize, ()> {
let entry_virt_addr = reader.virt_addr; let entry_virt_addr = reader.virt_addr;
let mut result = read_encoded_pointer(reader, encoding)?; let mut result = read_encoded_pointer(reader, encoding)?;
@ -223,11 +218,10 @@ pub struct EH_Frame<'a> {
} }
impl<'a> EH_Frame<'a> { impl<'a> EH_Frame<'a> {
/// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF /// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF
/// file. /// file.
pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> Result<EH_Frame, ()> { pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> EH_Frame {
Ok(EH_Frame { reader: DwarfReader::new(eh_frame_slice, eh_frame_addr) }) EH_Frame { reader: DwarfReader::new(eh_frame_slice, eh_frame_addr) }
} }
/// Returns an [Iterator] over all Call Frame Information (CFI) records. /// Returns an [Iterator] over all Call Frame Information (CFI) records.
@ -235,10 +229,7 @@ impl<'a> EH_Frame<'a> {
let reader = DwarfReader::from_reader(&self.reader, true); let reader = DwarfReader::from_reader(&self.reader, true);
let len = reader.slice.len(); let len = reader.slice.len();
CFI_Records { CFI_Records { reader, available: len }
reader,
available: len,
}
} }
} }
@ -255,7 +246,6 @@ pub struct CFI_Record<'a> {
} }
impl<'a> CFI_Record<'a> { impl<'a> CFI_Record<'a> {
pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result<CFI_Record<'a>, ()> { pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result<CFI_Record<'a>, ()> {
let length = cie_reader.read_u32(); let length = cie_reader.read_u32();
let fde_reader = match length { let fde_reader = match length {
@ -264,7 +254,7 @@ impl<'a> CFI_Record<'a> {
// length == u32::MAX means that the length is only representable with 64 bits, // length == u32::MAX means that the length is only representable with 64 bits,
// which does not make sense in a system with 32-bit address. // which does not make sense in a system with 32-bit address.
0xFFFFFFFF => unimplemented!(), 0xFFFF_FFFF => unimplemented!(),
_ => { _ => {
let mut fde_reader = DwarfReader::from_reader(cie_reader, false); let mut fde_reader = DwarfReader::from_reader(cie_reader, false);
@ -323,10 +313,7 @@ impl<'a> CFI_Record<'a> {
} }
assert_ne!(fde_pointer_encoding, DW_EH_PE_omit); assert_ne!(fde_pointer_encoding, DW_EH_PE_omit);
Ok(CFI_Record { Ok(CFI_Record { fde_pointer_encoding, fde_reader })
fde_pointer_encoding,
fde_reader,
})
} }
/// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI /// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI
@ -340,11 +327,7 @@ impl<'a> CFI_Record<'a> {
let reader = self.get_fde_reader(); let reader = self.get_fde_reader();
let len = reader.slice.len(); let len = reader.slice.len();
FDE_Records { FDE_Records { pointer_encoding: self.fde_pointer_encoding, reader, available: len }
pointer_encoding: self.fde_pointer_encoding,
reader,
available: len,
}
} }
} }
@ -371,7 +354,7 @@ impl<'a> Iterator for CFI_Records<'a> {
let length = match length { let length = match length {
// eh_frame with 0-length means the CIE is terminated // eh_frame with 0-length means the CIE is terminated
0 => return None, 0 => return None,
0xFFFFFFFF => unimplemented!("CIE entries larger than 4 bytes not supported"), 0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other, other => other,
} as usize; } as usize;
@ -387,7 +370,7 @@ impl<'a> Iterator for CFI_Records<'a> {
// Skip this record if it is a FDE // Skip this record if it is a FDE
if cie_ptr == 0 { if cie_ptr == 0 {
// Rewind back to the start of the CFI Record // Rewind back to the start of the CFI Record
return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap()) return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap());
} }
} }
} }
@ -417,7 +400,7 @@ impl<'a> Iterator for FDE_Records<'a> {
let length = match self.reader.read_u32() { let length = match self.reader.read_u32() {
// eh_frame with 0-length means the CIE is terminated // eh_frame with 0-length means the CIE is terminated
0 => return None, 0 => return None,
0xFFFFFFFF => unimplemented!("CIE entries larger than 4 bytes not supported"), 0xFFFF_FFFF => unimplemented!("CIE entries larger than 4 bytes not supported"),
other => other, other => other,
} as usize; } as usize;
@ -448,7 +431,6 @@ pub struct EH_Frame_Hdr<'a> {
} }
impl<'a> EH_Frame_Hdr<'a> { impl<'a> EH_Frame_Hdr<'a> {
/// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory. /// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory.
/// ///
/// Load address is not known at this point. /// Load address is not known at this point.
@ -459,15 +441,16 @@ impl<'a> EH_Frame_Hdr<'a> {
) -> EH_Frame_Hdr { ) -> EH_Frame_Hdr {
let mut writer = DwarfWriter::new(eh_frame_hdr_slice); let mut writer = DwarfWriter::new(eh_frame_hdr_slice);
writer.write_u8(1); // version writer.write_u8(1); // version
writer.write_u8(0x1B); // eh_frame_ptr_enc - PC-relative 4-byte signed value writer.write_u8(0x1B); // eh_frame_ptr_enc - PC-relative 4-byte signed value
writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value
writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value
let eh_frame_offset = eh_frame_addr let eh_frame_offset = eh_frame_addr.wrapping_sub(
.wrapping_sub(eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4)); eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4),
writer.write_u32(eh_frame_offset); // eh_frame_ptr );
writer.write_u32(0); // `fde_count`, will be written in finalize_fde writer.write_u32(eh_frame_offset); // eh_frame_ptr
writer.write_u32(0); // `fde_count`, will be written in finalize_fde
EH_Frame_Hdr { fde_writer: writer, eh_frame_hdr_addr, fdes: Vec::new() } EH_Frame_Hdr { fde_writer: writer, eh_frame_hdr_addr, fdes: Vec::new() }
} }
@ -492,7 +475,10 @@ impl<'a> EH_Frame_Hdr<'a> {
self.fde_writer.write_u32(*init_loc); self.fde_writer.write_u32(*init_loc);
self.fde_writer.write_u32(*addr); self.fde_writer.write_u32(*addr);
} }
LittleEndian::write_u32(&mut self.fde_writer.slice[Self::fde_count_offset()..], self.fdes.len() as u32); LittleEndian::write_u32(
&mut self.fde_writer.slice[Self::fde_count_offset()..],
self.fdes.len() as u32,
);
} }
pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize { pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize {
@ -504,7 +490,7 @@ impl<'a> EH_Frame_Hdr<'a> {
// The original length field should be able to hold the entire value. // The original length field should be able to hold the entire value.
// The device memory space is limited to 32-bits addresses anyway. // The device memory space is limited to 32-bits addresses anyway.
let entry_length = reader.read_u32(); let entry_length = reader.read_u32();
if entry_length == 0 || entry_length == 0xFFFFFFFF { if entry_length == 0 || entry_length == 0xFFFF_FFFF {
unimplemented!() unimplemented!()
} }
@ -515,7 +501,7 @@ impl<'a> EH_Frame_Hdr<'a> {
fde_count += 1; fde_count += 1;
} }
reader.offset(entry_length - mem::size_of::<u32>() as u32) reader.offset(entry_length - mem::size_of::<u32>() as u32);
} }
12 + fde_count * 8 12 + fde_count * 8

View File

@ -1,5 +1,5 @@
/* generated from elf.h with rust-bindgen and then manually altered */ /* generated from elf.h with rust-bindgen and then manually altered */
#![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code)] #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, dead_code, clippy::pedantic)]
pub const EI_NIDENT: usize = 16; pub const EI_NIDENT: usize = 16;
pub const EI_MAG0: usize = 0; pub const EI_MAG0: usize = 0;

View File

@ -1,3 +1,26 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::doc_markdown,
clippy::enum_glob_use,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::similar_names,
clippy::struct_field_names,
clippy::too_many_lines,
clippy::wildcard_imports
)]
use dwarf::*; use dwarf::*;
use elf::*; use elf::*;
use std::collections::HashMap; use std::collections::HashMap;
@ -70,45 +93,45 @@ struct SectionRecord<'a> {
data: Vec<u8>, data: Vec<u8>,
} }
fn read_unaligned<T: Copy>(data: &[u8], offset: usize) -> Result<T, ()> { fn read_unaligned<T: Copy>(data: &[u8], offset: usize) -> Option<T> {
if data.len() < offset + mem::size_of::<T>() { if data.len() < offset + mem::size_of::<T>() {
Err(()) None
} else { } else {
let ptr = data.as_ptr().wrapping_add(offset) as *const T; let ptr = data.as_ptr().wrapping_add(offset).cast();
Ok(unsafe { ptr::read_unaligned(ptr) }) Some(unsafe { ptr::read_unaligned(ptr) })
} }
} }
pub fn get_ref_slice<T: Copy>(data: &[u8], offset: usize, len: usize) -> Result<&[T], ()> { #[must_use]
pub fn get_ref_slice<T: Copy>(data: &[u8], offset: usize, len: usize) -> Option<&[T]> {
if data.len() < offset + mem::size_of::<T>() * len { if data.len() < offset + mem::size_of::<T>() * len {
Err(()) None
} else { } else {
let ptr = data.as_ptr().wrapping_add(offset) as *const T; let ptr = data.as_ptr().wrapping_add(offset).cast();
Ok(unsafe { slice::from_raw_parts(ptr, len) }) Some(unsafe { slice::from_raw_parts(ptr, len) })
} }
} }
fn from_struct_vec<T>(struct_vec: Vec<T>) -> Vec<u8> { fn from_struct_slice<T>(struct_vec: &[T]) -> Vec<u8> {
let ptr = struct_vec.as_ptr(); let ptr = struct_vec.as_ptr();
unsafe { slice::from_raw_parts(ptr as *const u8, struct_vec.len() * mem::size_of::<T>()) } unsafe { slice::from_raw_parts(ptr.cast(), mem::size_of_val(struct_vec)) }.to_vec()
.to_vec()
} }
fn to_struct_slice<T>(bytes: &[u8]) -> &[T] { fn to_struct_slice<T>(bytes: &[u8]) -> &[T] {
unsafe { slice::from_raw_parts(bytes.as_ptr() as *const T, bytes.len() / mem::size_of::<T>()) } unsafe { slice::from_raw_parts(bytes.as_ptr().cast(), bytes.len() / mem::size_of::<T>()) }
} }
fn to_struct_mut_slice<T>(bytes: &mut [u8]) -> &mut [T] { fn to_struct_mut_slice<T>(bytes: &mut [u8]) -> &mut [T] {
unsafe { unsafe {
slice::from_raw_parts_mut(bytes.as_mut_ptr() as *mut T, bytes.len() / mem::size_of::<T>()) slice::from_raw_parts_mut(bytes.as_mut_ptr().cast(), bytes.len() / mem::size_of::<T>())
} }
} }
fn elf_hash(name: &[u8]) -> u32 { fn elf_hash(name: &[u8]) -> u32 {
let mut h: u32 = 0; let mut h: u32 = 0;
for c in name { for c in name {
h = (h << 4) + *c as u32; h = (h << 4) + u32::from(*c);
let g = h & 0xf0000000; let g = h & 0xf000_0000;
if g != 0 { if g != 0 {
h ^= g >> 24; h ^= g >> 24;
h &= !g; h &= !g;
@ -202,22 +225,26 @@ impl<'a> Linker<'a> {
relocs: &[R], relocs: &[R],
target_section: Elf32_Word, target_section: Elf32_Word,
) -> Result<(), Error> { ) -> Result<(), Error> {
type RelocateFn = dyn Fn(&mut [u8], Elf32_Word);
struct RelocInfo<'a, R> {
pub defined_val: bool,
pub indirect_reloc: Option<&'a R>,
pub pc_relative: bool,
pub relocate: Option<Box<RelocateFn>>,
}
for reloc in relocs { for reloc in relocs {
let sym = match reloc.sym_info() as usize { let sym = match reloc.sym_info() as usize {
STN_UNDEF => None, STN_UNDEF => None,
sym_index => Some( sym_index => {
self.symtab Some(self.symtab.get(sym_index).ok_or("symbol out of bounds of symbol table")?)
.get(sym_index) }
.ok_or("symbol out of bounds of symbol table")?,
),
}; };
let resolve_symbol_addr = let resolve_symbol_addr =
|sym_option: Option<&Elf32_Sym>| -> Result<Elf32_Word, Error> { |sym_option: Option<&Elf32_Sym>| -> Result<Elf32_Word, Error> {
let sym = match sym_option { let Some(sym) = sym_option else { return Ok(0) };
Some(sym) => sym,
None => return Ok(0),
};
match sym.st_shndx { match sym.st_shndx {
SHN_UNDEF => Err(Error::Lookup("undefined symbol")), SHN_UNDEF => Err(Error::Lookup("undefined symbol")),
@ -244,13 +271,6 @@ impl<'a> Linker<'a> {
.ok_or(Error::Parsing("Cannot find section with matching sh_index")) .ok_or(Error::Parsing("Cannot find section with matching sh_index"))
}; };
struct RelocInfo<'a, R> {
pub defined_val: bool,
pub indirect_reloc: Option<&'a R>,
pub pc_relative: bool,
pub relocate: Option<Box<dyn Fn(&mut [u8], Elf32_Word)>>,
}
let classify = |reloc: &R, sym_option: Option<&Elf32_Sym>| -> Option<RelocInfo<R>> { let classify = |reloc: &R, sym_option: Option<&Elf32_Sym>| -> Option<RelocInfo<R>> {
let defined_val = sym_option.map_or(true, |sym| { let defined_val = sym_option.map_or(true, |sym| {
sym.st_shndx != SHN_UNDEF || ELF32_ST_BIND(sym.st_info) == STB_LOCAL sym.st_shndx != SHN_UNDEF || ELF32_ST_BIND(sym.st_info) == STB_LOCAL
@ -262,7 +282,7 @@ impl<'a> Linker<'a> {
indirect_reloc: None, indirect_reloc: None,
pc_relative: true, pc_relative: true,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32(target_word, value) LittleEndian::write_u32(target_word, value);
})), })),
}), }),
@ -273,9 +293,9 @@ impl<'a> Linker<'a> {
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32( LittleEndian::write_u32(
target_word, target_word,
(LittleEndian::read_u32(target_word) & 0x80000000) (LittleEndian::read_u32(target_word) & 0x8000_0000)
| value & 0x7FFFFFFF, | value & 0x7FFF_FFFF,
) );
})), })),
}), }),
@ -297,8 +317,8 @@ impl<'a> Linker<'a> {
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
let auipc_raw = LittleEndian::read_u32(target_word); let auipc_raw = LittleEndian::read_u32(target_word);
let auipc_insn = let auipc_insn =
(auipc_raw & 0xFFF) | ((value + 0x800) & 0xFFFFF000); (auipc_raw & 0xFFF) | ((value + 0x800) & 0xFFFF_F000);
LittleEndian::write_u32(target_word, auipc_insn) LittleEndian::write_u32(target_word, auipc_insn);
})), })),
}) })
} }
@ -308,15 +328,14 @@ impl<'a> Linker<'a> {
indirect_reloc: None, indirect_reloc: None,
pc_relative: true, pc_relative: true,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32(target_word, value) LittleEndian::write_u32(target_word, value);
})), })),
}), }),
R_RISCV_PCREL_LO12_I => { R_RISCV_PCREL_LO12_I => {
let expected_offset = sym_option.map_or(0, |sym| sym.st_value); let expected_offset = sym_option.map_or(0, |sym| sym.st_value);
let indirect_reloc = relocs let indirect_reloc =
.iter() relocs.iter().find(|reloc| reloc.offset() == expected_offset)?;
.find(|reloc| reloc.offset() == expected_offset)?;
Some(RelocInfo { Some(RelocInfo {
defined_val: { defined_val: {
let indirect_sym = let indirect_sym =
@ -330,14 +349,14 @@ impl<'a> Linker<'a> {
// Here, we convert to direct addressing // Here, we convert to direct addressing
// GOT reloc (indirect) -> lw + addi // GOT reloc (indirect) -> lw + addi
// PCREL reloc (direct) -> addi // PCREL reloc (direct) -> addi
let (lo_opcode, lo_funct3) = (0b0010011, 0b000); let (lo_opcode, lo_funct3) = (0b001_0011, 0b000);
let addi_lw_raw = LittleEndian::read_u32(target_word); let addi_lw_raw = LittleEndian::read_u32(target_word);
let addi_insn = lo_opcode let addi_insn = lo_opcode
| (addi_lw_raw & 0xF8F80) | (addi_lw_raw & 0xF8F80)
| (lo_funct3 << 12) | (lo_funct3 << 12)
| ((value & 0xFFF) << 20); | ((value & 0xFFF) << 20);
LittleEndian::write_u32(target_word, addi_insn) LittleEndian::write_u32(target_word, addi_insn);
})), })),
}) })
} }
@ -354,10 +373,7 @@ impl<'a> Linker<'a> {
indirect_reloc: None, indirect_reloc: None,
pc_relative: false, pc_relative: false,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32( LittleEndian::write_u32(target_word, value);
target_word,
value,
)
})), })),
}), }),
@ -367,7 +383,7 @@ impl<'a> Linker<'a> {
pc_relative: false, pc_relative: false,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
let old_value = LittleEndian::read_u32(target_word); let old_value = LittleEndian::read_u32(target_word);
LittleEndian::write_u32(target_word, old_value.wrapping_add(value)) LittleEndian::write_u32(target_word, old_value.wrapping_add(value));
})), })),
}), }),
@ -377,7 +393,7 @@ impl<'a> Linker<'a> {
pc_relative: false, pc_relative: false,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
let old_value = LittleEndian::read_u32(target_word); let old_value = LittleEndian::read_u32(target_word);
LittleEndian::write_u32(target_word, old_value.wrapping_sub(value)) LittleEndian::write_u32(target_word, old_value.wrapping_sub(value));
})), })),
}), }),
@ -386,10 +402,7 @@ impl<'a> Linker<'a> {
indirect_reloc: None, indirect_reloc: None,
pc_relative: false, pc_relative: false,
relocate: Some(Box::new(|target_word, value| { relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u16( LittleEndian::write_u16(target_word, value as u16);
target_word,
value as u16,
)
})), })),
}), }),
@ -402,7 +415,7 @@ impl<'a> Linker<'a> {
LittleEndian::write_u16( LittleEndian::write_u16(
target_word, target_word,
old_value.wrapping_add(value as u16), old_value.wrapping_add(value as u16),
) );
})), })),
}), }),
@ -415,7 +428,7 @@ impl<'a> Linker<'a> {
LittleEndian::write_u16( LittleEndian::write_u16(
target_word, target_word,
old_value.wrapping_sub(value as u16), old_value.wrapping_sub(value as u16),
) );
})), })),
}), }),
@ -497,7 +510,7 @@ impl<'a> Linker<'a> {
if let Some(relocate) = reloc_info.relocate { if let Some(relocate) = reloc_info.relocate {
let target_word = &mut target_sec_image[reloc.offset() as usize..]; let target_word = &mut target_sec_image[reloc.offset() as usize..];
relocate(target_word, value) relocate(target_word, value);
} else { } else {
self.rela_dyn_relas.push(Elf32_Rela { self.rela_dyn_relas.push(Elf32_Rela {
r_offset: rela_off, r_offset: rela_off,
@ -545,16 +558,18 @@ impl<'a> Linker<'a> {
let eh_frame_slice = eh_frame_rec.data.as_slice(); let eh_frame_slice = eh_frame_rec.data.as_slice();
// Prepare a new buffer to dodge borrow check // Prepare a new buffer to dodge borrow check
let mut eh_frame_hdr_vec: Vec<u8> = vec![0; eh_frame_hdr_rec.shdr.sh_size as usize]; let mut eh_frame_hdr_vec: Vec<u8> = vec![0; eh_frame_hdr_rec.shdr.sh_size as usize];
let eh_frame = EH_Frame::new(eh_frame_slice, eh_frame_rec.shdr.sh_offset) let eh_frame = EH_Frame::new(eh_frame_slice, eh_frame_rec.shdr.sh_offset);
.map_err(|()| "cannot read EH frame")?;
let mut eh_frame_hdr = EH_Frame_Hdr::new( let mut eh_frame_hdr = EH_Frame_Hdr::new(
eh_frame_hdr_vec.as_mut_slice(), eh_frame_hdr_vec.as_mut_slice(),
eh_frame_hdr_rec.shdr.sh_offset, eh_frame_hdr_rec.shdr.sh_offset,
eh_frame_rec.shdr.sh_offset, eh_frame_rec.shdr.sh_offset,
); );
eh_frame.cfi_records() eh_frame.cfi_records().flat_map(|cfi| cfi.fde_records()).for_each(&mut |(
.flat_map(|cfi| cfi.fde_records()) init_pos,
.for_each(&mut |(init_pos, virt_addr)| eh_frame_hdr.add_fde(init_pos, virt_addr)); virt_addr,
)| {
eh_frame_hdr.add_fde(init_pos, virt_addr);
});
// Sort FDE entries in .eh_frame_hdr // Sort FDE entries in .eh_frame_hdr
eh_frame_hdr.finalize_fde(); eh_frame_hdr.finalize_fde();
@ -568,55 +583,129 @@ impl<'a> Linker<'a> {
} }
pub fn ld(data: &'a [u8]) -> Result<Vec<u8>, Error> { pub fn ld(data: &'a [u8]) -> Result<Vec<u8>, Error> {
let ehdr = read_unaligned::<Elf32_Ehdr>(data, 0).map_err(|()| "cannot read ELF header")?; fn allocate_rela_dyn<R: Relocatable>(
linker: &Linker,
relocs: &[R],
) -> Result<(usize, Vec<u32>), Error> {
let mut alloc_size = 0;
let mut rela_dyn_sym_indices = Vec::new();
for reloc in relocs {
if reloc.sym_info() as usize == STN_UNDEF {
continue;
}
let sym: &Elf32_Sym = linker
.symtab
.get(reloc.sym_info() as usize)
.ok_or("symbol out of bounds of symbol table")?;
match (linker.isa, reloc.type_info()) {
// Absolute address relocations
// A runtime relocation is needed to find the loading address
(Isa::CortexA9, R_ARM_ABS32) | (Isa::RiscV32, R_RISCV_32) => {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// Relative address relocations
// Relay the relocation to the runtime linker only if the symbol is not defined
(Isa::CortexA9, R_ARM_REL32 | R_ARM_PREL31 | R_ARM_TARGET2)
| (
Isa::RiscV32,
R_RISCV_CALL_PLT | R_RISCV_PCREL_HI20 | R_RISCV_GOT_HI20 | R_RISCV_32_PCREL
| R_RISCV_SET32 | R_RISCV_ADD32 | R_RISCV_SUB32 | R_RISCV_SET16
| R_RISCV_ADD16 | R_RISCV_SUB16 | R_RISCV_SET8 | R_RISCV_ADD8
| R_RISCV_SUB8 | R_RISCV_SET6 | R_RISCV_SUB6,
) => {
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// RISC-V: Lower 12-bits relocations
// If the upper 20-bits relocation cannot be resolved,
// this relocation will be relayed to the runtime linker.
(Isa::RiscV32, R_RISCV_PCREL_LO12_I) => {
// Find the HI20 relocation
let indirect_reloc = relocs
.iter()
.find(|reloc| reloc.offset() == sym.st_value)
.ok_or("malformatted LO12 relocation")?;
let indirect_sym = linker.symtab[indirect_reloc.sym_info() as usize];
if ELF32_ST_BIND(indirect_sym.st_info) == STB_GLOBAL
&& indirect_sym.st_shndx == SHN_UNDEF
{
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
_ => {
println!("Relocation type 0x{:X?} is not supported", reloc.type_info());
unimplemented!()
}
}
}
Ok((alloc_size, rela_dyn_sym_indices))
}
let Some(ehdr) = read_unaligned::<Elf32_Ehdr>(data, 0) else {
Err("cannot read ELF header")?
};
let isa = match ehdr.e_machine { let isa = match ehdr.e_machine {
EM_ARM => Isa::CortexA9, EM_ARM => Isa::CortexA9,
EM_RISCV => Isa::RiscV32, EM_RISCV => Isa::RiscV32,
_ => return Err(Error::Parsing("unsupported architecture")), _ => return Err(Error::Parsing("unsupported architecture")),
}; };
let shdrs = get_ref_slice::<Elf32_Shdr>(data, ehdr.e_shoff as usize, ehdr.e_shnum as usize) let Some(shdrs) =
.map_err(|()| "cannot read section header table")?; get_ref_slice::<Elf32_Shdr>(data, ehdr.e_shoff as usize, ehdr.e_shnum as usize)
else {
Err("cannot read section header table")?
};
// Read .strtab // Read .strtab
let strtab_shdr = shdrs[ehdr.e_shstrndx as usize]; let strtab_shdr = shdrs[ehdr.e_shstrndx as usize];
let strtab = let Some(strtab) =
get_ref_slice::<u8>(data, strtab_shdr.sh_offset as usize, strtab_shdr.sh_size as usize) get_ref_slice::<u8>(data, strtab_shdr.sh_offset as usize, strtab_shdr.sh_size as usize)
.map_err(|()| "cannot read the string table from data")?; else {
Err("cannot read the string table from data")?
};
// Read .symtab // Read .symtab
let symtab_shdr = shdrs let symtab_shdr = shdrs
.iter() .iter()
.find(|shdr| shdr.sh_type as usize == SHT_SYMTAB) .find(|shdr| shdr.sh_type as usize == SHT_SYMTAB)
.ok_or(Error::Parsing("cannot find the symbol table"))?; .ok_or(Error::Parsing("cannot find the symbol table"))?;
let symtab = get_ref_slice::<Elf32_Sym>( let Some(symtab) = get_ref_slice::<Elf32_Sym>(
data, data,
symtab_shdr.sh_offset as usize, symtab_shdr.sh_offset as usize,
symtab_shdr.sh_size as usize / mem::size_of::<Elf32_Sym>(), symtab_shdr.sh_size as usize / mem::size_of::<Elf32_Sym>(),
) ) else {
.map_err(|()| "cannot read the symbol table from data")?; Err("cannot read the symbol table from data")?
};
// Section table for the .elf paired with the section name // Section table for the .elf paired with the section name
// To be formalized incrementally // To be formalized incrementally
// Very hashmap-like structure, but the order matters, so it is a vector // Very hashmap-like structure, but the order matters, so it is a vector
let elf_shdrs = vec![ let elf_shdrs = vec![SectionRecord {
SectionRecord { shdr: Elf32_Shdr {
shdr: Elf32_Shdr { sh_name: 0,
sh_name: 0, sh_type: 0,
sh_type: 0, sh_flags: 0,
sh_flags: 0, sh_addr: 0,
sh_addr: 0, sh_offset: 0,
sh_offset: 0, sh_size: 0,
sh_size: 0, sh_link: 0,
sh_link: 0, sh_info: 0,
sh_info: 0, sh_addralign: 0,
sh_addralign: 0, sh_entsize: 0,
sh_entsize: 0,
},
name: "",
data: vec![0; 0],
}, },
]; name: "",
data: vec![0; 0],
}];
let elf_sh_data_off = mem::size_of::<Elf32_Ehdr>() + mem::size_of::<Elf32_Phdr>() * 5; let elf_sh_data_off = mem::size_of::<Elf32_Ehdr>() + mem::size_of::<Elf32_Phdr>() * 5;
// Image of the linked dynamic library, to be formalized incrementally // Image of the linked dynamic library, to be formalized incrementally
@ -752,21 +841,27 @@ impl<'a> Linker<'a> {
($shdr: expr, $stmt: expr) => { ($shdr: expr, $stmt: expr) => {
match $shdr.sh_type as usize { match $shdr.sh_type as usize {
SHT_RELA => { SHT_RELA => {
let relocs = get_ref_slice::<Elf32_Rela>( let Some(relocs) = get_ref_slice::<Elf32_Rela>(
data, data,
$shdr.sh_offset as usize, $shdr.sh_offset as usize,
$shdr.sh_size as usize / mem::size_of::<Elf32_Rela>(), $shdr.sh_size as usize / mem::size_of::<Elf32_Rela>(),
) ) else {
.map_err(|()| "cannot parse relocations")?; Err("cannot parse relocations")?
};
#[allow(clippy::redundant_closure_call)]
$stmt(relocs) $stmt(relocs)
} }
SHT_REL => { SHT_REL => {
let relocs = get_ref_slice::<Elf32_Rel>( let Some(relocs) = get_ref_slice::<Elf32_Rel>(
data, data,
$shdr.sh_offset as usize, $shdr.sh_offset as usize,
$shdr.sh_size as usize / mem::size_of::<Elf32_Rel>(), $shdr.sh_size as usize / mem::size_of::<Elf32_Rel>(),
) ) else {
.map_err(|()| "cannot parse relocations")?; Err("cannot parse relocations")?
};
#[allow(clippy::redundant_closure_call)]
$stmt(relocs) $stmt(relocs)
} }
_ => unreachable!(), _ => unreachable!(),
@ -774,84 +869,6 @@ impl<'a> Linker<'a> {
}; };
} }
fn allocate_rela_dyn<R: Relocatable>(
linker: &Linker,
relocs: &[R],
) -> Result<(usize, Vec<u32>), Error> {
let mut alloc_size = 0;
let mut rela_dyn_sym_indices = Vec::new();
for reloc in relocs {
if reloc.sym_info() as usize == STN_UNDEF {
continue;
}
let sym: &Elf32_Sym = linker
.symtab
.get(reloc.sym_info() as usize)
.ok_or("symbol out of bounds of symbol table")?;
match (linker.isa, reloc.type_info()) {
// Absolute address relocations
// A runtime relocation is needed to find the loading address
(Isa::CortexA9, R_ARM_ABS32) | (Isa::RiscV32, R_RISCV_32) => {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// Relative address relocations
// Relay the relocation to the runtime linker only if the symbol is not defined
(Isa::CortexA9, R_ARM_REL32)
| (Isa::CortexA9, R_ARM_PREL31)
| (Isa::CortexA9, R_ARM_TARGET2)
| (Isa::RiscV32, R_RISCV_CALL_PLT)
| (Isa::RiscV32, R_RISCV_PCREL_HI20)
| (Isa::RiscV32, R_RISCV_GOT_HI20)
| (Isa::RiscV32, R_RISCV_32_PCREL)
| (Isa::RiscV32, R_RISCV_SET32)
| (Isa::RiscV32, R_RISCV_ADD32)
| (Isa::RiscV32, R_RISCV_SUB32)
| (Isa::RiscV32, R_RISCV_SET16)
| (Isa::RiscV32, R_RISCV_ADD16)
| (Isa::RiscV32, R_RISCV_SUB16)
| (Isa::RiscV32, R_RISCV_SET8)
| (Isa::RiscV32, R_RISCV_ADD8)
| (Isa::RiscV32, R_RISCV_SUB8)
| (Isa::RiscV32, R_RISCV_SET6)
| (Isa::RiscV32, R_RISCV_SUB6) => {
if ELF32_ST_BIND(sym.st_info) == STB_GLOBAL && sym.st_shndx == SHN_UNDEF {
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
// RISC-V: Lower 12-bits relocations
// If the upper 20-bits relocation cannot be resolved,
// this relocation will be relayed to the runtime linker.
(Isa::RiscV32, R_RISCV_PCREL_LO12_I) => {
// Find the HI20 relocation
let indirect_reloc = relocs
.iter()
.find(|reloc| reloc.offset() == sym.st_value)
.ok_or("malformatted LO12 relocation")?;
let indirect_sym = linker.symtab[indirect_reloc.sym_info() as usize];
if ELF32_ST_BIND(indirect_sym.st_info) == STB_GLOBAL
&& indirect_sym.st_shndx == SHN_UNDEF
{
alloc_size += mem::size_of::<Elf32_Rela>(); // FIXME: RELA vs REL
rela_dyn_sym_indices.push(reloc.sym_info());
}
}
_ => {
println!("Relocation type 0x{:X?} is not supported", reloc.type_info());
unimplemented!()
}
}
}
Ok((alloc_size, rela_dyn_sym_indices))
}
for shdr in shdrs for shdr in shdrs
.iter() .iter()
.filter(|shdr| shdr.sh_type as usize == SHT_REL || shdr.sh_type as usize == SHT_RELA) .filter(|shdr| shdr.sh_type as usize == SHT_REL || shdr.sh_type as usize == SHT_RELA)
@ -879,7 +896,7 @@ impl<'a> Linker<'a> {
} }
// Avoid symbol duplication // Avoid symbol duplication
rela_dyn_sym_indices.sort(); rela_dyn_sym_indices.sort_unstable();
rela_dyn_sym_indices.dedup(); rela_dyn_sym_indices.dedup();
if rela_dyn_size != 0 { if rela_dyn_size != 0 {
@ -1010,7 +1027,9 @@ impl<'a> Linker<'a> {
let mut hash_bucket: Vec<u32> = vec![0; dynsym.len()]; let mut hash_bucket: Vec<u32> = vec![0; dynsym.len()];
let mut hash_chain: Vec<u32> = vec![0; dynsym.len()]; let mut hash_chain: Vec<u32> = vec![0; dynsym.len()];
for (sym_index, (str_start, str_end)) in dynsym_names.iter().enumerate().take(dynsym.len()).skip(1) { for (sym_index, (str_start, str_end)) in
dynsym_names.iter().enumerate().take(dynsym.len()).skip(1)
{
let hash = elf_hash(&dynstr[*str_start..*str_end]); let hash = elf_hash(&dynstr[*str_start..*str_end]);
let mut hash_index = hash as usize % hash_bucket.len(); let mut hash_index = hash as usize % hash_bucket.len();
@ -1062,7 +1081,7 @@ impl<'a> Linker<'a> {
sh_entsize: mem::size_of::<Elf32_Sym>() as Elf32_Word, sh_entsize: mem::size_of::<Elf32_Sym>() as Elf32_Word,
}, },
".dynsym", ".dynsym",
from_struct_vec(dynsym), from_struct_slice(&dynsym),
); );
let hash_elf_index = linker.load_section( let hash_elf_index = linker.load_section(
&Elf32_Shdr { &Elf32_Shdr {
@ -1078,7 +1097,7 @@ impl<'a> Linker<'a> {
sh_entsize: 4, sh_entsize: 4,
}, },
".hash", ".hash",
from_struct_vec(hash), from_struct_slice(&hash),
); );
// Link .rela.dyn header to the .dynsym header // Link .rela.dyn header to the .dynsym header
@ -1177,7 +1196,7 @@ impl<'a> Linker<'a> {
}; };
let dynamic_elf_index = let dynamic_elf_index =
linker.load_section(&dynamic_shdr, ".dynamic", from_struct_vec(dyn_entries)); linker.load_section(&dynamic_shdr, ".dynamic", from_struct_slice(&dyn_entries));
let last_w_sec_elf_index = linker.elf_shdrs.len() - 1; let last_w_sec_elf_index = linker.elf_shdrs.len() - 1;
@ -1253,7 +1272,9 @@ impl<'a> Linker<'a> {
update_dynsym_record!(b"__bss_start", bss_offset, bss_elf_index as Elf32_Section); update_dynsym_record!(b"__bss_start", bss_offset, bss_elf_index as Elf32_Section);
update_dynsym_record!(b"_end", bss_offset, bss_elf_index as Elf32_Section); update_dynsym_record!(b"_end", bss_offset, bss_elf_index as Elf32_Section);
} else { } else {
for (bss_iter_index, &(bss_section_index, section_name)) in bss_index_vec.iter().enumerate() { for (bss_iter_index, &(bss_section_index, section_name)) in
bss_index_vec.iter().enumerate()
{
let shdr = &shdrs[bss_section_index]; let shdr = &shdrs[bss_section_index];
let bss_elf_index = linker.load_section( let bss_elf_index = linker.load_section(
shdr, shdr,
@ -1326,7 +1347,7 @@ impl<'a> Linker<'a> {
// Prepare a STRTAB to hold the names of section headers // Prepare a STRTAB to hold the names of section headers
// Fix the sh_name field of the section headers // Fix the sh_name field of the section headers
let mut shstrtab = Vec::new(); let mut shstrtab = Vec::new();
for shdr_rec in linker.elf_shdrs.iter_mut() { for shdr_rec in &mut linker.elf_shdrs {
let shstrtab_index = shstrtab.len(); let shstrtab_index = shstrtab.len();
shstrtab.extend(shdr_rec.name.as_bytes()); shstrtab.extend(shdr_rec.name.as_bytes());
shstrtab.push(0); shstrtab.push(0);
@ -1367,20 +1388,17 @@ impl<'a> Linker<'a> {
let alignment = (4 - (linker.image.len() % 4)) % 4; let alignment = (4 - (linker.image.len() % 4)) % 4;
let sec_headers_offset = linker.image.len() + alignment; let sec_headers_offset = linker.image.len() + alignment;
linker.image.extend(vec![0; alignment]); linker.image.extend(vec![0; alignment]);
for rec in linker.elf_shdrs.iter() { for rec in &linker.elf_shdrs {
let shdr = rec.shdr; let shdr = rec.shdr;
linker.image.extend(unsafe { linker.image.extend(unsafe {
slice::from_raw_parts( slice::from_raw_parts(ptr::addr_of!(shdr).cast(), mem::size_of::<Elf32_Shdr>())
&shdr as *const Elf32_Shdr as *const u8,
mem::size_of::<Elf32_Shdr>(),
)
}); });
} }
// Update the PHDRs // Update the PHDRs
let phdr_offset = mem::size_of::<Elf32_Ehdr>(); let phdr_offset = mem::size_of::<Elf32_Ehdr>();
unsafe { unsafe {
let phdr_ptr = linker.image.as_mut_ptr().add(phdr_offset) as *mut Elf32_Phdr; let phdr_ptr = linker.image.as_mut_ptr().add(phdr_offset).cast();
let phdr_slice = slice::from_raw_parts_mut(phdr_ptr, 5); let phdr_slice = slice::from_raw_parts_mut(phdr_ptr, 5);
// List of program headers: // List of program headers:
// 1. ELF headers & program headers // 1. ELF headers & program headers
@ -1457,7 +1475,7 @@ impl<'a> Linker<'a> {
} }
// Update the EHDR // Update the EHDR
let ehdr_ptr = linker.image.as_mut_ptr() as *mut Elf32_Ehdr; let ehdr_ptr = linker.image.as_mut_ptr().cast();
unsafe { unsafe {
*ehdr_ptr = Elf32_Ehdr { *ehdr_ptr = Elf32_Ehdr {
e_ident: ehdr.e_ident, e_ident: ehdr.e_ident,

View File

@ -1,15 +1,15 @@
use lalrpop_util::ParseError;
use nac3ast::*;
use crate::ast::Ident; use crate::ast::Ident;
use crate::ast::Location; use crate::ast::Location;
use crate::token::Tok;
use crate::error::*; use crate::error::*;
use crate::token::Tok;
use lalrpop_util::ParseError;
use nac3ast::*;
pub fn make_config_comment( pub fn make_config_comment(
com_loc: Location, com_loc: Location,
stmt_loc: Location, stmt_loc: Location,
nac3com_above: Vec<(Ident, Tok)>, nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident> nac3com_end: Option<Ident>,
) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> { ) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> {
if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() { if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() {
return Err(ParseError::User { return Err(ParseError::User {
@ -17,24 +17,25 @@ pub fn make_config_comment(
location: com_loc, location: com_loc,
error: LexicalErrorType::OtherError( error: LexicalErrorType::OtherError(
format!( format!(
"config comment at top must have the same indentation with what it applies (comment at {}, statement at {})", "config comment at top must have the same indentation with what it applies (comment at {com_loc}, statement at {stmt_loc})",
com_loc,
stmt_loc,
) )
) )
} }
}) });
}; };
Ok( Ok(nac3com_above
nac3com_above .into_iter()
.into_iter() .map(|(com, _)| com)
.map(|(com, _)| com) .chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter()))
.chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter())) .collect())
.collect()
)
} }
pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, Tok)>, nac3com_end: Option<Ident>, com_above_loc: Location) -> Result<(), ParseError<Location, Tok, LexicalError>> { pub fn handle_small_stmt<U>(
stmts: &mut [Stmt<U>],
nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident>,
com_above_loc: Location,
) -> Result<(), ParseError<Location, Tok, LexicalError>> {
if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() { if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() {
return Err(ParseError::User { return Err(ParseError::User {
error: LexicalError { error: LexicalError {
@ -47,17 +48,12 @@ pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, To
) )
) )
} }
}) });
} }
apply_config_comments( apply_config_comments(&mut stmts[0], nac3com_above.into_iter().map(|(com, _)| com).collect());
&mut stmts[0],
nac3com_above
.into_iter()
.map(|(com, _)| com).collect()
);
apply_config_comments( apply_config_comments(
stmts.last_mut().unwrap(), stmts.last_mut().unwrap(),
nac3com_end.map_or_else(Vec::new, |com| vec![com]) nac3com_end.map_or_else(Vec::new, |com| vec![com]),
); );
Ok(()) Ok(())
} }
@ -72,7 +68,7 @@ fn apply_config_comments<U>(stmt: &mut Stmt<U>, comments: Vec<Ident>) {
| StmtKind::AnnAssign { config_comment, .. } | StmtKind::AnnAssign { config_comment, .. }
| StmtKind::Break { config_comment, .. } | StmtKind::Break { config_comment, .. }
| StmtKind::Continue { config_comment, .. } | StmtKind::Continue { config_comment, .. }
| StmtKind::Return { config_comment, .. } | StmtKind::Return { config_comment, .. }
| StmtKind::Raise { config_comment, .. } | StmtKind::Raise { config_comment, .. }
| StmtKind::Import { config_comment, .. } | StmtKind::Import { config_comment, .. }
| StmtKind::ImportFrom { config_comment, .. } | StmtKind::ImportFrom { config_comment, .. }
@ -80,6 +76,8 @@ fn apply_config_comments<U>(stmt: &mut Stmt<U>, comments: Vec<Ident>) {
| StmtKind::Nonlocal { config_comment, .. } | StmtKind::Nonlocal { config_comment, .. }
| StmtKind::Assert { config_comment, .. } => config_comment.extend(comments), | StmtKind::Assert { config_comment, .. } => config_comment.extend(comments),
_ => { unreachable!("only small statements should call this function") } _ => {
unreachable!("only small statements should call this function")
}
} }
} }

View File

@ -37,7 +37,7 @@ impl fmt::Display for LexicalErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self { match self {
LexicalErrorType::StringError => write!(f, "Got unexpected string"), LexicalErrorType::StringError => write!(f, "Got unexpected string"),
LexicalErrorType::FStringError(error) => write!(f, "Got error in f-string: {}", error), LexicalErrorType::FStringError(error) => write!(f, "Got error in f-string: {error}"),
LexicalErrorType::UnicodeError => write!(f, "Got unexpected unicode"), LexicalErrorType::UnicodeError => write!(f, "Got unexpected unicode"),
LexicalErrorType::NestingError => write!(f, "Got unexpected nesting"), LexicalErrorType::NestingError => write!(f, "Got unexpected nesting"),
LexicalErrorType::IndentationError => { LexicalErrorType::IndentationError => {
@ -59,13 +59,13 @@ impl fmt::Display for LexicalErrorType {
write!(f, "positional argument follows keyword argument") write!(f, "positional argument follows keyword argument")
} }
LexicalErrorType::UnrecognizedToken { tok } => { LexicalErrorType::UnrecognizedToken { tok } => {
write!(f, "Got unexpected token {}", tok) write!(f, "Got unexpected token {tok}")
} }
LexicalErrorType::LineContinuationError => { LexicalErrorType::LineContinuationError => {
write!(f, "unexpected character after line continuation character") write!(f, "unexpected character after line continuation character")
} }
LexicalErrorType::Eof => write!(f, "unexpected EOF while parsing"), LexicalErrorType::Eof => write!(f, "unexpected EOF while parsing"),
LexicalErrorType::OtherError(msg) => write!(f, "{}", msg), LexicalErrorType::OtherError(msg) => write!(f, "{msg}"),
} }
} }
} }
@ -96,7 +96,7 @@ impl fmt::Display for FStringErrorType {
FStringErrorType::UnopenedRbrace => write!(f, "Unopened '}}'"), FStringErrorType::UnopenedRbrace => write!(f, "Unopened '}}'"),
FStringErrorType::ExpectedRbrace => write!(f, "Expected '}}' after conversion flag."), FStringErrorType::ExpectedRbrace => write!(f, "Expected '}}' after conversion flag."),
FStringErrorType::InvalidExpression(error) => { FStringErrorType::InvalidExpression(error) => {
write!(f, "Invalid expression: {}", error) write!(f, "Invalid expression: {error}")
} }
FStringErrorType::InvalidConversionFlag => write!(f, "Invalid conversion flag"), FStringErrorType::InvalidConversionFlag => write!(f, "Invalid conversion flag"),
FStringErrorType::EmptyExpression => write!(f, "Empty expression"), FStringErrorType::EmptyExpression => write!(f, "Empty expression"),
@ -144,36 +144,27 @@ pub enum ParseErrorType {
impl From<LalrpopError<Location, Tok, LexicalError>> for ParseError { impl From<LalrpopError<Location, Tok, LexicalError>> for ParseError {
fn from(err: LalrpopError<Location, Tok, LexicalError>) -> Self { fn from(err: LalrpopError<Location, Tok, LexicalError>) -> Self {
match err { match err {
// TODO: Are there cases where this isn't an EOF? LalrpopError::ExtraToken { token } => {
LalrpopError::InvalidToken { location } => ParseError { ParseError { error: ParseErrorType::ExtraToken(token.1), location: token.0 }
error: ParseErrorType::Eof, }
location, LalrpopError::User { error } => {
}, ParseError { error: ParseErrorType::Lexical(error.error), location: error.location }
LalrpopError::ExtraToken { token } => ParseError { }
error: ParseErrorType::ExtraToken(token.1),
location: token.0,
},
LalrpopError::User { error } => ParseError {
error: ParseErrorType::Lexical(error.error),
location: error.location,
},
LalrpopError::UnrecognizedToken { token, expected } => { LalrpopError::UnrecognizedToken { token, expected } => {
// Hacky, but it's how CPython does it. See PyParser_AddToken, // Hacky, but it's how CPython does it. See PyParser_AddToken,
// in particular "Only one possible expected token" comment. // in particular "Only one possible expected token" comment.
let expected = if expected.len() == 1 { let expected = if expected.len() == 1 { Some(expected[0].clone()) } else { None };
Some(expected[0].clone())
} else {
None
};
ParseError { ParseError {
error: ParseErrorType::UnrecognizedToken(token.1, expected), error: ParseErrorType::UnrecognizedToken(token.1, expected),
location: token.0, location: token.0,
} }
} }
LalrpopError::UnrecognizedEof { location, .. } => ParseError {
error: ParseErrorType::Eof, LalrpopError::UnrecognizedEof { location, .. }
location, // TODO: Are there cases where this isn't an EOF?
}, | LalrpopError::InvalidToken { location } => {
ParseError { error: ParseErrorType::Eof, location }
}
} }
} }
} }
@ -188,7 +179,7 @@ impl fmt::Display for ParseErrorType {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self { match *self {
ParseErrorType::Eof => write!(f, "Got unexpected EOF"), ParseErrorType::Eof => write!(f, "Got unexpected EOF"),
ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {:?}", tok), ParseErrorType::ExtraToken(ref tok) => write!(f, "Got extraneous token: {tok:?}"),
ParseErrorType::InvalidToken => write!(f, "Got invalid token"), ParseErrorType::InvalidToken => write!(f, "Got invalid token"),
ParseErrorType::UnrecognizedToken(ref tok, ref expected) => { ParseErrorType::UnrecognizedToken(ref tok, ref expected) => {
if *tok == Tok::Indent { if *tok == Tok::Indent {
@ -196,10 +187,10 @@ impl fmt::Display for ParseErrorType {
} else if expected.as_deref() == Some("Indent") { } else if expected.as_deref() == Some("Indent") {
write!(f, "expected an indented block") write!(f, "expected an indented block")
} else { } else {
write!(f, "Got unexpected token {}", tok) write!(f, "Got unexpected token {tok}")
} }
} }
ParseErrorType::Lexical(ref error) => write!(f, "{}", error), ParseErrorType::Lexical(ref error) => write!(f, "{error}"),
} }
} }
} }
@ -207,6 +198,7 @@ impl fmt::Display for ParseErrorType {
impl Error for ParseErrorType {} impl Error for ParseErrorType {}
impl ParseErrorType { impl ParseErrorType {
#[must_use]
pub fn is_indentation_error(&self) -> bool { pub fn is_indentation_error(&self) -> bool {
match self { match self {
ParseErrorType::Lexical(LexicalErrorType::IndentationError) => true, ParseErrorType::Lexical(LexicalErrorType::IndentationError) => true,
@ -216,11 +208,11 @@ impl ParseErrorType {
_ => false, _ => false,
} }
} }
#[must_use]
pub fn is_tab_error(&self) -> bool { pub fn is_tab_error(&self) -> bool {
matches!( matches!(
self, self,
ParseErrorType::Lexical(LexicalErrorType::TabError) ParseErrorType::Lexical(LexicalErrorType::TabError | LexicalErrorType::TabsAfterSpaces)
| ParseErrorType::Lexical(LexicalErrorType::TabsAfterSpaces)
) )
} }
} }

View File

@ -15,10 +15,7 @@ struct FStringParser<'a> {
impl<'a> FStringParser<'a> { impl<'a> FStringParser<'a> {
fn new(source: &'a str, str_location: Location) -> Self { fn new(source: &'a str, str_location: Location) -> Self {
Self { Self { chars: source.chars().peekable(), str_location }
chars: source.chars().peekable(),
str_location,
}
} }
#[inline] #[inline]
@ -133,10 +130,10 @@ impl<'a> FStringParser<'a> {
) )
} else { } else {
Box::new(self.expr(ExprKind::Constant { Box::new(self.expr(ExprKind::Constant {
value: spec_expression.to_owned().into(), value: spec_expression.clone().into(),
kind: None, kind: None,
})) }))
}) });
} }
'(' | '{' | '[' => { '(' | '{' | '[' => {
expression.push(ch); expression.push(ch);
@ -251,17 +248,11 @@ impl<'a> FStringParser<'a> {
} }
if !content.is_empty() { if !content.is_empty() {
values.push(self.expr(ExprKind::Constant { values.push(self.expr(ExprKind::Constant { value: content.into(), kind: None }));
value: content.into(),
kind: None,
}))
} }
let s = match values.len() { let s = match values.len() {
0 => self.expr(ExprKind::Constant { 0 => self.expr(ExprKind::Constant { value: String::new().into(), kind: None }),
value: String::new().into(),
kind: None,
}),
1 => values.into_iter().next().unwrap(), 1 => values.into_iter().next().unwrap(),
_ => self.expr(ExprKind::JoinedStr { values }), _ => self.expr(ExprKind::JoinedStr { values }),
}; };
@ -270,16 +261,14 @@ impl<'a> FStringParser<'a> {
} }
fn parse_fstring_expr(source: &str) -> Result<Expr, ParseError> { fn parse_fstring_expr(source: &str) -> Result<Expr, ParseError> {
let fstring_body = format!("({})", source); let fstring_body = format!("({source})");
parse_expression(&fstring_body) parse_expression(&fstring_body)
} }
/// Parse an fstring from a string, located at a certain position in the sourcecode. /// Parse an fstring from a string, located at a certain position in the sourcecode.
/// In case of errors, we will get the location and the error returned. /// In case of errors, we will get the location and the error returned.
pub fn parse_located_fstring(source: &str, location: Location) -> Result<Expr, FStringError> { pub fn parse_located_fstring(source: &str, location: Location) -> Result<Expr, FStringError> {
FStringParser::new(source, location) FStringParser::new(source, location).parse().map_err(|error| FStringError { error, location })
.parse()
.map_err(|error| FStringError { error, location })
} }
#[cfg(test)] #[cfg(test)]

View File

@ -54,38 +54,32 @@ pub fn parse_args(func_args: Vec<FunctionArgument>) -> Result<ArgumentList, Lexi
let mut keyword_names = HashSet::with_capacity_and_hasher(func_args.len(), RandomState::new()); let mut keyword_names = HashSet::with_capacity_and_hasher(func_args.len(), RandomState::new());
for (name, value) in func_args { for (name, value) in func_args {
match name { if let Some((location, name)) = name {
Some((location, name)) => { if let Some(keyword_name) = &name {
if let Some(keyword_name) = &name { if keyword_names.contains(keyword_name) {
if keyword_names.contains(keyword_name) {
return Err(LexicalError {
error: LexicalErrorType::DuplicateKeywordArgumentError,
location,
});
}
keyword_names.insert(keyword_name.clone());
}
keywords.push(ast::Keyword::new(
location,
ast::KeywordData {
arg: name.map(|name| name.into()),
value: Box::new(value),
},
));
}
None => {
// Allow starred args after keyword arguments.
if !keywords.is_empty() && !is_starred(&value) {
return Err(LexicalError { return Err(LexicalError {
error: LexicalErrorType::PositionalArgumentError, error: LexicalErrorType::DuplicateKeywordArgumentError,
location: value.location, location,
}); });
} }
args.push(value); keyword_names.insert(keyword_name.clone());
} }
keywords.push(ast::Keyword::new(
location,
ast::KeywordData { arg: name.map(String::into), value: Box::new(value) },
));
} else {
// Allow starred args after keyword arguments.
if !keywords.is_empty() && !is_starred(&value) {
return Err(LexicalError {
error: LexicalErrorType::PositionalArgumentError,
location: value.location,
});
}
args.push(value);
} }
} }
Ok(ArgumentList { args, keywords }) Ok(ArgumentList { args, keywords })

View File

@ -3,12 +3,12 @@
//! This means source code is translated into separate tokens. //! This means source code is translated into separate tokens.
pub use super::token::Tok; pub use super::token::Tok;
use crate::ast::{Location, FileName}; use crate::ast::{FileName, Location};
use crate::error::{LexicalError, LexicalErrorType}; use crate::error::{LexicalError, LexicalErrorType};
use std::char; use std::char;
use std::cmp::Ordering; use std::cmp::Ordering;
use std::str::FromStr;
use std::num::IntErrorKind; use std::num::IntErrorKind;
use std::str::FromStr;
use unic_emoji_char::is_emoji_presentation; use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start}; use unic_ucd_ident::{is_xid_continue, is_xid_start};
@ -32,20 +32,14 @@ impl IndentationLevel {
if self.spaces <= other.spaces { if self.spaces <= other.spaces {
Ok(Ordering::Less) Ok(Ordering::Less)
} else { } else {
Err(LexicalError { Err(LexicalError { location, error: LexicalErrorType::TabError })
location,
error: LexicalErrorType::TabError,
})
} }
} }
Ordering::Greater => { Ordering::Greater => {
if self.spaces >= other.spaces { if self.spaces >= other.spaces {
Ok(Ordering::Greater) Ok(Ordering::Greater)
} else { } else {
Err(LexicalError { Err(LexicalError { location, error: LexicalErrorType::TabError })
location,
error: LexicalErrorType::TabError,
})
} }
} }
Ordering::Equal => Ok(self.spaces.cmp(&other.spaces)), Ordering::Equal => Ok(self.spaces.cmp(&other.spaces)),
@ -63,7 +57,7 @@ pub struct Lexer<T: Iterator<Item = char>> {
chr1: Option<char>, chr1: Option<char>,
chr2: Option<char>, chr2: Option<char>,
location: Location, location: Location,
config_comment_prefix: Option<&'static str> config_comment_prefix: Option<&'static str>,
} }
pub static KEYWORDS: phf::Map<&'static str, Tok> = phf::phf_map! { pub static KEYWORDS: phf::Map<&'static str, Tok> = phf::phf_map! {
@ -136,11 +130,7 @@ where
T: Iterator<Item = char>, T: Iterator<Item = char>,
{ {
pub fn new(source: T) -> Self { pub fn new(source: T) -> Self {
let mut nlh = NewlineHandler { let mut nlh = NewlineHandler { source, chr0: None, chr1: None };
source,
chr0: None,
chr1: None,
};
nlh.shift(); nlh.shift();
nlh.shift(); nlh.shift();
nlh nlh
@ -169,7 +159,7 @@ where
self.shift(); self.shift();
} else { } else {
// Transform MAC EOL into \n // Transform MAC EOL into \n
self.chr0 = Some('\n') self.chr0 = Some('\n');
} }
} else { } else {
break; break;
@ -189,13 +179,13 @@ where
chars: input, chars: input,
at_begin_of_line: true, at_begin_of_line: true,
nesting: 0, nesting: 0,
indentation_stack: vec![Default::default()], indentation_stack: vec![IndentationLevel::default()],
pending: Vec::new(), pending: Vec::new(),
chr0: None, chr0: None,
location: start, location: start,
chr1: None, chr1: None,
chr2: None, chr2: None,
config_comment_prefix: Some(" nac3:") config_comment_prefix: Some(" nac3:"),
}; };
lxr.next_char(); lxr.next_char();
lxr.next_char(); lxr.next_char();
@ -217,11 +207,9 @@ where
let mut saw_f = false; let mut saw_f = false;
loop { loop {
// Detect r"", f"", b"" and u"" // Detect r"", f"", b"" and u""
if !(saw_b || saw_u || saw_f) && matches!(self.chr0, Some('b') | Some('B')) { if !(saw_b || saw_u || saw_f) && matches!(self.chr0, Some('b' | 'B')) {
saw_b = true; saw_b = true;
} else if !(saw_b || saw_r || saw_u || saw_f) } else if !(saw_b || saw_r || saw_u || saw_f) && matches!(self.chr0, Some('u' | 'U')) {
&& matches!(self.chr0, Some('u') | Some('U'))
{
saw_u = true; saw_u = true;
} else if !(saw_r || saw_u) && (self.chr0 == Some('r') || self.chr0 == Some('R')) { } else if !(saw_r || saw_u) && (self.chr0 == Some('r') || self.chr0 == Some('R')) {
saw_r = true; saw_r = true;
@ -287,15 +275,15 @@ where
let end_pos = self.get_pos(); let end_pos = self.get_pos();
let value = match i128::from_str_radix(&value_text, radix) { let value = match i128::from_str_radix(&value_text, radix) {
Ok(value) => value, Ok(value) => value,
Err(e) => { Err(e) => match e.kind() {
match e.kind() { IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX,
IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX, _ => {
_ => return Err(LexicalError { return Err(LexicalError {
error: LexicalErrorType::OtherError(format!("{:?}", e)), error: LexicalErrorType::OtherError(format!("{e:?}")),
location: start_pos, location: start_pos,
}), })
} }
} },
}; };
Ok((start_pos, Tok::Int { value }, end_pos)) Ok((start_pos, Tok::Int { value }, end_pos))
} }
@ -338,14 +326,7 @@ where
if self.chr0 == Some('j') || self.chr0 == Some('J') { if self.chr0 == Some('j') || self.chr0 == Some('J') {
self.next_char(); self.next_char();
let end_pos = self.get_pos(); let end_pos = self.get_pos();
Ok(( Ok((start_pos, Tok::Complex { real: 0.0, imag: value }, end_pos))
start_pos,
Tok::Complex {
real: 0.0,
imag: value,
},
end_pos,
))
} else { } else {
let end_pos = self.get_pos(); let end_pos = self.get_pos();
Ok((start_pos, Tok::Float { value }, end_pos)) Ok((start_pos, Tok::Float { value }, end_pos))
@ -364,7 +345,7 @@ where
let value = value_text.parse::<i128>().ok(); let value = value_text.parse::<i128>().ok();
let nonzero = match value { let nonzero = match value {
Some(value) => value != 0i128, Some(value) => value != 0i128,
None => true None => true,
}; };
if start_is_zero && nonzero { if start_is_zero && nonzero {
return Err(LexicalError { return Err(LexicalError {
@ -379,7 +360,7 @@ where
/// Consume a sequence of numbers with the given radix, /// Consume a sequence of numbers with the given radix,
/// the digits can be decorated with underscores /// the digits can be decorated with underscores
/// like this: '1_2_3_4' == '1234' /// like this: `'1_2_3_4'` == `'1234'`
fn radix_run(&mut self, radix: u32) -> String { fn radix_run(&mut self, radix: u32) -> String {
let mut value_text = String::new(); let mut value_text = String::new();
@ -412,7 +393,7 @@ where
2 => matches!(c, Some('0'..='1')), 2 => matches!(c, Some('0'..='1')),
8 => matches!(c, Some('0'..='7')), 8 => matches!(c, Some('0'..='7')),
10 => matches!(c, Some('0'..='9')), 10 => matches!(c, Some('0'..='9')),
16 => matches!(c, Some('0'..='9') | Some('a'..='f') | Some('A'..='F')), 16 => matches!(c, Some('0'..='9' | 'a'..='f' | 'A'..='F')),
other => unimplemented!("Radix not implemented: {}", other), other => unimplemented!("Radix not implemented: {}", other),
} }
} }
@ -420,8 +401,8 @@ where
/// Test if we face '[eE][-+]?[0-9]+' /// Test if we face '[eE][-+]?[0-9]+'
fn at_exponent(&self) -> bool { fn at_exponent(&self) -> bool {
match self.chr0 { match self.chr0 {
Some('e') | Some('E') => match self.chr1 { Some('e' | 'E') => match self.chr1 {
Some('+') | Some('-') => matches!(self.chr2, Some('0'..='9')), Some('+' | '-') => matches!(self.chr2, Some('0'..='9')),
Some('0'..='9') => true, Some('0'..='9') => true,
_ => false, _ => false,
}, },
@ -433,19 +414,17 @@ where
fn lex_comment(&mut self) -> Option<Spanned> { fn lex_comment(&mut self) -> Option<Spanned> {
self.next_char(); self.next_char();
// if possibly nac3 pseudocomment, special handling for `# nac3:` // if possibly nac3 pseudocomment, special handling for `# nac3:`
let (mut prefix, mut is_comment) = self let (mut prefix, mut is_comment) =
.config_comment_prefix self.config_comment_prefix.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
// for the correct location of config comment // for the correct location of config comment
let mut start_loc = self.location; let mut start_loc = self.location;
start_loc.go_left(); start_loc.go_left();
loop { loop {
match self.chr0 { match self.chr0 {
Some('\n') => return None, Some('\n') | None => return None,
None => return None,
Some(c) => { Some(c) => {
if let (true, Some(p)) = (is_comment, prefix.next()) { if let (true, Some(p)) = (is_comment, prefix.next()) {
is_comment = is_comment && c == p is_comment = is_comment && c == p;
} else { } else {
// done checking prefix, if is comment then return the spanned // done checking prefix, if is comment then return the spanned
if is_comment { if is_comment {
@ -460,22 +439,20 @@ where
return Some(( return Some((
start_loc, start_loc,
Tok::ConfigComment { content: content.trim().into() }, Tok::ConfigComment { content: content.trim().into() },
self.location self.location,
)); ));
} }
} }
} }
} }
self.next_char(); self.next_char();
}; }
} }
fn unicode_literal(&mut self, literal_number: usize) -> Result<char, LexicalError> { fn unicode_literal(&mut self, literal_number: usize) -> Result<char, LexicalError> {
let mut p: u32 = 0u32; let mut p: u32 = 0u32;
let unicode_error = LexicalError { let unicode_error =
error: LexicalErrorType::UnicodeError, LexicalError { error: LexicalErrorType::UnicodeError, location: self.get_pos() };
location: self.get_pos(),
};
for i in 1..=literal_number { for i in 1..=literal_number {
match self.next_char() { match self.next_char() {
Some(c) => match c.to_digit(16) { Some(c) => match c.to_digit(16) {
@ -496,7 +473,7 @@ where
octet_content.push(first); octet_content.push(first);
while octet_content.len() < 3 { while octet_content.len() < 3 {
if let Some('0'..='7') = self.chr0 { if let Some('0'..='7') = self.chr0 {
octet_content.push(self.next_char().unwrap()) octet_content.push(self.next_char().unwrap());
} else { } else {
break; break;
} }
@ -530,10 +507,8 @@ where
} }
} }
} }
unicode_names2::character(&name).ok_or(LexicalError { unicode_names2::character(&name)
error: LexicalErrorType::UnicodeError, .ok_or(LexicalError { error: LexicalErrorType::UnicodeError, location: start_pos })
location: start_pos,
})
} }
fn lex_string( fn lex_string(
@ -566,7 +541,7 @@ where
} else if is_raw { } else if is_raw {
string_content.push('\\'); string_content.push('\\');
if let Some(c) = self.next_char() { if let Some(c) = self.next_char() {
string_content.push(c) string_content.push(c);
} else { } else {
return Err(LexicalError { return Err(LexicalError {
error: LexicalErrorType::StringError, error: LexicalErrorType::StringError,
@ -599,7 +574,7 @@ where
Some('u') if !is_bytes => string_content.push(self.unicode_literal(4)?), Some('u') if !is_bytes => string_content.push(self.unicode_literal(4)?),
Some('U') if !is_bytes => string_content.push(self.unicode_literal(8)?), Some('U') if !is_bytes => string_content.push(self.unicode_literal(8)?),
Some('N') if !is_bytes => { Some('N') if !is_bytes => {
string_content.push(self.parse_unicode_name()?) string_content.push(self.parse_unicode_name()?);
} }
Some(c) => { Some(c) => {
string_content.push('\\'); string_content.push('\\');
@ -650,20 +625,15 @@ where
let end_pos = self.get_pos(); let end_pos = self.get_pos();
let tok = if is_bytes { let tok = if is_bytes {
Tok::Bytes { Tok::Bytes { value: string_content.chars().map(|c| c as u8).collect() }
value: string_content.chars().map(|c| c as u8).collect(),
}
} else { } else {
Tok::String { Tok::String { value: string_content, is_fstring }
value: string_content,
is_fstring,
}
}; };
Ok((start_pos, tok, end_pos)) Ok((start_pos, tok, end_pos))
} }
fn is_identifier_start(&self, c: char) -> bool { fn is_identifier_start(c: char) -> bool {
match c { match c {
'_' | 'a'..='z' | 'A'..='Z' => true, '_' | 'a'..='z' | 'A'..='Z' => true,
'+' | '-' | '*' | '/' | '=' | ' ' | '<' | '>' => false, '+' | '-' | '*' | '/' | '=' | ' ' | '<' | '>' => false,
@ -835,18 +805,14 @@ where
// Check if we have some character: // Check if we have some character:
if let Some(c) = self.chr0 { if let Some(c) = self.chr0 {
// First check identifier: // First check identifier:
if self.is_identifier_start(c) { if Self::is_identifier_start(c) {
let identifier = self.lex_identifier()?; let identifier = self.lex_identifier()?;
self.emit(identifier); self.emit(identifier);
} else if is_emoji_presentation(c) { } else if is_emoji_presentation(c) {
let tok_start = self.get_pos(); let tok_start = self.get_pos();
self.next_char(); self.next_char();
let tok_end = self.get_pos(); let tok_end = self.get_pos();
self.emit(( self.emit((tok_start, Tok::Name { name: c.to_string().into() }, tok_end));
tok_start,
Tok::Name { name: c.to_string().into() },
tok_end,
));
} else { } else {
self.consume_character(c)?; self.consume_character(c)?;
} }
@ -899,16 +865,13 @@ where
'=' => { '=' => {
let tok_start = self.get_pos(); let tok_start = self.get_pos();
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::EqEqual, tok_end));
self.emit((tok_start, Tok::EqEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::Equal, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::Equal, tok_end));
}
} }
} }
'+' => { '+' => {
@ -934,16 +897,13 @@ where
} }
Some('*') => { Some('*') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::DoubleStarEqual, tok_end));
self.emit((tok_start, Tok::DoubleStarEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::DoubleStar, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::DoubleStar, tok_end));
}
} }
} }
_ => { _ => {
@ -963,16 +923,13 @@ where
} }
Some('/') => { Some('/') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::DoubleSlashEqual, tok_end));
self.emit((tok_start, Tok::DoubleSlashEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::DoubleSlash, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::DoubleSlash, tok_end));
}
} }
} }
_ => { _ => {
@ -1141,16 +1098,13 @@ where
match self.chr0 { match self.chr0 {
Some('<') => { Some('<') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::LeftShiftEqual, tok_end));
self.emit((tok_start, Tok::LeftShiftEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::LeftShift, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::LeftShift, tok_end));
}
} }
} }
Some('=') => { Some('=') => {
@ -1170,16 +1124,13 @@ where
match self.chr0 { match self.chr0 {
Some('>') => { Some('>') => {
self.next_char(); self.next_char();
match self.chr0 { if let Some('=') = self.chr0 {
Some('=') => { self.next_char();
self.next_char(); let tok_end = self.get_pos();
let tok_end = self.get_pos(); self.emit((tok_start, Tok::RightShiftEqual, tok_end));
self.emit((tok_start, Tok::RightShiftEqual, tok_end)); } else {
} let tok_end = self.get_pos();
_ => { self.emit((tok_start, Tok::RightShift, tok_end));
let tok_end = self.get_pos();
self.emit((tok_start, Tok::RightShift, tok_end));
}
} }
} }
Some('=') => { Some('=') => {
@ -1439,14 +1390,8 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::String { Tok::String { value: "\\\\".to_owned(), is_fstring: false },
value: "\\\\".to_owned(), Tok::String { value: "\\".to_owned(), is_fstring: false },
is_fstring: false,
},
Tok::String {
value: "\\".to_owned(),
is_fstring: false,
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1459,27 +1404,13 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::Int { Tok::Int { value: 47i128 },
value: 47i128, Tok::Int { value: 13i128 },
}, Tok::Int { value: 0i128 },
Tok::Int { Tok::Int { value: 123i128 },
value: 13i128,
},
Tok::Int {
value: 0i128,
},
Tok::Int {
value: 123i128,
},
Tok::Float { value: 0.2 }, Tok::Float { value: 0.2 },
Tok::Complex { Tok::Complex { real: 0.0, imag: 2.0 },
real: 0.0, Tok::Complex { real: 0.0, imag: 2.2 },
imag: 2.0,
},
Tok::Complex {
real: 0.0,
imag: 2.2,
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1539,21 +1470,13 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::Name { Tok::Name { name: String::from("avariable").into() },
name: String::from("avariable").into(),
},
Tok::Equal, Tok::Equal,
Tok::Int { Tok::Int { value: 99i128 },
value: 99i128
},
Tok::Plus, Tok::Plus,
Tok::Int { Tok::Int { value: 2i128 },
value: 2i128
},
Tok::Minus, Tok::Minus,
Tok::Int { Tok::Int { value: 0i128 },
value: 0i128
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1740,42 +1663,15 @@ class Foo(A, B):
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![
Tok::String { Tok::String { value: String::from("double"), is_fstring: false },
value: String::from("double"), Tok::String { value: String::from("single"), is_fstring: false },
is_fstring: false, Tok::String { value: String::from("can't"), is_fstring: false },
}, Tok::String { value: String::from("\\\""), is_fstring: false },
Tok::String { Tok::String { value: String::from("\t\r\n"), is_fstring: false },
value: String::from("single"), Tok::String { value: String::from("\\g"), is_fstring: false },
is_fstring: false, Tok::String { value: String::from("raw\\'"), is_fstring: false },
}, Tok::String { value: String::from("Đ"), is_fstring: false },
Tok::String { Tok::String { value: String::from("\u{80}\u{0}a"), is_fstring: false },
value: String::from("can't"),
is_fstring: false,
},
Tok::String {
value: String::from("\\\""),
is_fstring: false,
},
Tok::String {
value: String::from("\t\r\n"),
is_fstring: false,
},
Tok::String {
value: String::from("\\g"),
is_fstring: false,
},
Tok::String {
value: String::from("raw\\'"),
is_fstring: false,
},
Tok::String {
value: String::from("Đ"),
is_fstring: false,
},
Tok::String {
value: String::from("\u{80}\u{0}a"),
is_fstring: false,
},
Tok::Newline, Tok::Newline,
] ]
); );
@ -1840,41 +1736,17 @@ class Foo(A, B):
fn test_raw_byte_literal() { fn test_raw_byte_literal() {
let source = r"rb'\x1z'"; let source = r"rb'\x1z'";
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(tokens, vec![Tok::Bytes { value: b"\\x1z".to_vec() }, Tok::Newline]);
tokens,
vec![
Tok::Bytes {
value: b"\\x1z".to_vec()
},
Tok::Newline
]
);
let source = r"rb'\\'"; let source = r"rb'\\'";
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(tokens, vec![Tok::Bytes { value: b"\\\\".to_vec() }, Tok::Newline])
tokens,
vec![
Tok::Bytes {
value: b"\\\\".to_vec()
},
Tok::Newline
]
)
} }
#[test] #[test]
fn test_escape_octet() { fn test_escape_octet() {
let source = r##"b'\43a\4\1234'"##; let source = r##"b'\43a\4\1234'"##;
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(tokens, vec![Tok::Bytes { value: b"#a\x04S4".to_vec() }, Tok::Newline])
tokens,
vec![
Tok::Bytes {
value: b"#a\x04S4".to_vec()
},
Tok::Newline
]
)
} }
#[test] #[test]
@ -1883,13 +1755,7 @@ class Foo(A, B):
let tokens = lex_source(source); let tokens = lex_source(source);
assert_eq!( assert_eq!(
tokens, tokens,
vec![ vec![Tok::String { value: "\u{2002}".to_owned(), is_fstring: false }, Tok::Newline]
Tok::String {
value: "\u{2002}".to_owned(),
is_fstring: false,
},
Tok::Newline
]
) )
} }
} }

View File

@ -15,6 +15,24 @@
//! //!
//! ``` //! ```
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(
clippy::enum_glob_use,
clippy::fn_params_excessive_bools,
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::module_name_repetitions,
clippy::too_many_lines,
clippy::wildcard_imports
)]
#[macro_use] #[macro_use]
extern crate log; extern crate log;
use lalrpop_util::lalrpop_mod; use lalrpop_util::lalrpop_mod;
@ -27,9 +45,16 @@ pub mod lexer;
pub mod mode; pub mod mode;
pub mod parser; pub mod parser;
lalrpop_mod!( lalrpop_mod!(
#[allow(clippy::all)] #[allow(
#[allow(unused)] future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
unused,
clippy::all,
clippy::pedantic
)]
python python
); );
pub mod token;
pub mod config_comment_helper; pub mod config_comment_helper;
pub mod token;

View File

@ -5,6 +5,7 @@
//! parse a whole program, a single statement, or a single //! parse a whole program, a single statement, or a single
//! expression. //! expression.
use nac3ast::Location;
use std::iter; use std::iter;
use crate::ast::{self, FileName}; use crate::ast::{self, FileName};
@ -63,7 +64,7 @@ pub fn parse_program(source: &str, file: FileName) -> Result<ast::Suite, ParseEr
/// ///
/// ``` /// ```
pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> { pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> {
parse(source, Mode::Expression, Default::default()).map(|top| match top { parse(source, Mode::Expression, FileName::default()).map(|top| match top {
ast::Mod::Expression { body } => *body, ast::Mod::Expression { body } => *body,
_ => unreachable!(), _ => unreachable!(),
}) })
@ -72,12 +73,10 @@ pub fn parse_expression(source: &str) -> Result<ast::Expr, ParseError> {
// Parse a given source code // Parse a given source code
pub fn parse(source: &str, mode: Mode, file: FileName) -> Result<ast::Mod, ParseError> { pub fn parse(source: &str, mode: Mode, file: FileName) -> Result<ast::Mod, ParseError> {
let lxr = lexer::make_tokenizer(source, file); let lxr = lexer::make_tokenizer(source, file);
let marker_token = (Default::default(), mode.to_marker(), Default::default()); let marker_token = (Location::default(), mode.to_marker(), Location::default());
let tokenizer = iter::once(Ok(marker_token)).chain(lxr); let tokenizer = iter::once(Ok(marker_token)).chain(lxr);
python::TopParser::new() python::TopParser::new().parse(tokenizer).map_err(ParseError::from)
.parse(tokenizer)
.map_err(ParseError::from)
} }
#[cfg(test)] #[cfg(test)]
@ -163,7 +162,7 @@ class Foo(A, B):
let parse_ast = parse_expression(&source).unwrap(); let parse_ast = parse_expression(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast); insta::assert_debug_snapshot!(parse_ast);
} }
#[test] #[test]
fn test_more_comment() { fn test_more_comment() {
let source = "\ let source = "\
@ -185,7 +184,7 @@ while i < 2: # nac3: 4
3"; 3";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap()); insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap());
} }
#[test] #[test]
fn test_sample_comment() { fn test_sample_comment() {
let source = "\ let source = "\

View File

@ -1,7 +1,7 @@
//! Different token definitions. //! Different token definitions.
//! Loosely based on token.h from CPython source: //! Loosely based on token.h from CPython source:
use std::fmt::{self, Write};
use crate::ast; use crate::ast;
use std::fmt::{self, Write};
/// Python source code can be tokenized in a sequence of these tokens. /// Python source code can be tokenized in a sequence of these tokens.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@ -111,15 +111,23 @@ impl fmt::Display for Tok {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use Tok::*; use Tok::*;
match self { match self {
Name { name } => write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name)), Name { name } => {
Int { value } => if *value != i128::MAX { write!(f, "'{}'", value) } else { write!(f, "'#OFL#'") }, write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name))
Float { value } => write!(f, "'{}'", value), }
Complex { real, imag } => write!(f, "{}j{}", real, imag), Int { value } => {
if *value == i128::MAX {
write!(f, "'#OFL#'")
} else {
write!(f, "'{value}'")
}
}
Float { value } => write!(f, "'{value}'"),
Complex { real, imag } => write!(f, "{real}j{imag}"),
String { value, is_fstring } => { String { value, is_fstring } => {
if *is_fstring { if *is_fstring {
write!(f, "f")? write!(f, "f")?;
} }
write!(f, "{:?}", value) write!(f, "{value:?}")
} }
Bytes { value } => { Bytes { value } => {
write!(f, "b\"")?; write!(f, "b\"")?;
@ -129,12 +137,16 @@ impl fmt::Display for Tok {
10 => f.write_str("\\n")?, 10 => f.write_str("\\n")?,
13 => f.write_str("\\r")?, 13 => f.write_str("\\r")?,
32..=126 => f.write_char(*i as char)?, 32..=126 => f.write_char(*i as char)?,
_ => write!(f, "\\x{:02x}", i)?, _ => write!(f, "\\x{i:02x}")?,
} }
} }
f.write_str("\"") f.write_str("\"")
} }
ConfigComment { content } => write!(f, "ConfigComment: '{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)), ConfigComment { content } => write!(
f,
"ConfigComment: '{}'",
ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)
),
Newline => f.write_str("Newline"), Newline => f.write_str("Newline"),
Indent => f.write_str("Indent"), Indent => f.write_str("Indent"),
Dedent => f.write_str("Dedent"), Dedent => f.write_str("Dedent"),

View File

@ -9,8 +9,8 @@ use nac3core::{
}; };
use nac3parser::ast::{self, StrRef}; use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use std::{collections::HashMap, sync::Arc};
use std::collections::HashSet; use std::collections::HashSet;
use std::{collections::HashMap, sync::Arc};
pub struct ResolverInternal { pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<StrRef, Type>>, pub id_to_type: Mutex<HashMap<StrRef, Type>>,
@ -63,10 +63,12 @@ impl SymbolResolver for Resolver {
} }
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> { fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0.id_to_def.lock().get(&id).copied() self.0
.ok_or_else(|| HashSet::from([ .id_to_def
format!("Undefined identifier `{id}`"), .lock()
])) .get(&id)
.copied()
.ok_or_else(|| HashSet::from([format!("Undefined identifier `{id}`")]))
} }
fn get_string_id(&self, s: &str) -> i32 { fn get_string_id(&self, s: &str) -> i32 {

View File

@ -1,14 +1,22 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(clippy::too_many_lines, clippy::wildcard_imports)]
use clap::Parser; use clap::Parser;
use inkwell::{ use inkwell::{
memory_buffer::MemoryBuffer, memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
OptimizationLevel, OptimizationLevel,
}; };
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use std::collections::HashSet; use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use nac3core::{ use nac3core::{
codegen::{ codegen::{
@ -18,7 +26,7 @@ use nac3core::{
symbol_resolver::SymbolResolver, symbol_resolver::SymbolResolver,
toplevel::{ toplevel::{
composer::{ComposerConfig, TopLevelComposer}, composer::{ComposerConfig, TopLevelComposer},
helper::parse_parameter_default_value, helper::parse_parameter_default_value,
type_annotation::*, type_annotation::*,
TopLevelDef, TopLevelDef,
}, },
@ -78,19 +86,18 @@ fn handle_typevar_definition(
primitives: &PrimitiveStore, primitives: &PrimitiveStore,
) -> Result<Type, HashSet<String>> { ) -> Result<Type, HashSet<String>> {
let ExprKind::Call { func, args, .. } = &var.node else { let ExprKind::Call { func, args, .. } = &var.node else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "expression {var:?} cannot be handled as a generic parameter in global scope"
"expression {var:?} cannot be handled as a generic parameter in global scope" )]));
),
]))
}; };
match &func.node { match &func.node {
ExprKind::Name { id, .. } if id == &"TypeVar".into() => { ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else { let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!("Expected string constant for first parameter of `TypeVar`, got {:?}", &args[0].node), "Expected string constant for first parameter of `TypeVar`, got {:?}",
])) &args[0].node
)]));
}; };
let generic_name: StrRef = ty_name.to_string().into(); let generic_name: StrRef = ty_name.to_string().into();
@ -104,19 +111,17 @@ fn handle_typevar_definition(
unifier, unifier,
primitives, primitives,
x, x,
HashMap::default(), HashMap::new(),
)?; )?;
get_type_from_type_annotation_kinds( get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)
def_list, unifier, &ty, &mut None
)
}) })
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let loc = func.location; let loc = func.location;
if constraints.len() == 1 { if constraints.len() == 1 {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!("A single constraint is not allowed (at {loc})"), "A single constraint is not allowed (at {loc})"
])) )]));
} }
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0) Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0)
@ -124,18 +129,17 @@ fn handle_typevar_definition(
ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => { ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
if args.len() != 2 { if args.len() != 2 {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!("Expected 2 arguments for `ConstGeneric`, got {}", args.len()), "Expected 2 arguments for `ConstGeneric`, got {}",
])) args.len()
)]));
} }
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else { let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(HashSet::from([ return Err(HashSet::from([format!(
format!( "Expected string constant for first parameter of `ConstGeneric`, got {:?}",
"Expected string constant for first parameter of `ConstGeneric`, got {:?}", &args[0].node
&args[0].node )]));
),
]))
}; };
let generic_name: StrRef = ty_name.to_string().into(); let generic_name: StrRef = ty_name.to_string().into();
@ -145,21 +149,18 @@ fn handle_typevar_definition(
unifier, unifier,
primitives, primitives,
&args[1], &args[1],
HashMap::default(), HashMap::new(),
)?;
let constraint = get_type_from_type_annotation_kinds(
def_list, unifier, &ty, &mut None
)?; )?;
let constraint =
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?;
let loc = func.location; let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0) Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0)
} }
_ => Err(HashSet::from([ _ => Err(HashSet::from([format!(
format!( "expression {var:?} cannot be handled as a generic parameter in global scope"
"expression {var:?} cannot be handled as a generic parameter in global scope" )])),
),
]))
} }
} }
@ -175,18 +176,12 @@ fn handle_assignment_pattern(
if targets.len() == 1 { if targets.len() == 1 {
match &targets[0].node { match &targets[0].node {
ExprKind::Name { id, .. } => { ExprKind::Name { id, .. } => {
if let Ok(var) = handle_typevar_definition( if let Ok(var) =
value, handle_typevar_definition(value, resolver, def_list, unifier, primitives)
resolver, {
def_list,
unifier,
primitives,
) {
internal_resolver.add_id_type(*id, var); internal_resolver.add_id_type(*id, var);
Ok(()) Ok(())
} else if let Ok(val) = } else if let Ok(val) = parse_parameter_default_value(value, resolver) {
parse_parameter_default_value(value, resolver)
{
internal_resolver.add_module_global(*id, val); internal_resolver.add_module_global(*id, val);
Ok(()) Ok(())
} else { } else {
@ -238,10 +233,7 @@ fn handle_assignment_pattern(
)) ))
} }
} }
_ => Err(format!( _ => Err(format!("unpack of this expression is not supported at {}", value.location)),
"unpack of this expression is not supported at {}",
value.location
)),
} }
} }
} }
@ -250,15 +242,8 @@ fn main() {
const SIZE_T: u32 = usize::BITS; const SIZE_T: u32 = usize::BITS;
let cli = CommandLineArgs::parse(); let cli = CommandLineArgs::parse();
let CommandLineArgs { let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
file_name, cli;
threads,
opt_level,
emit_llvm,
triple,
mcpu,
target_features,
} = cli;
Target::initialize_all(&InitializationConfig::default()); Target::initialize_all(&InitializationConfig::default());
@ -270,11 +255,9 @@ fn main() {
let target_features = target_features.unwrap_or_default(); let target_features = target_features.unwrap_or_default();
let threads = if is_multithreaded() { let threads = if is_multithreaded() {
if threads == 0 { if threads == 0 {
std::thread::available_parallelism() std::thread::available_parallelism().map(NonZeroUsize::get).unwrap_or(1usize)
.map(|threads| threads.get() as u32)
.unwrap_or(1u32)
} else { } else {
threads threads as usize
} }
} else { } else {
if threads != 1 { if threads != 1 {
@ -308,7 +291,8 @@ fn main() {
class_names: Mutex::default(), class_names: Mutex::default(),
module_globals: Mutex::default(), module_globals: Mutex::default(),
str_store: Mutex::default(), str_store: Mutex::default(),
}.into(); }
.into();
let resolver = let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>; Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
@ -332,13 +316,16 @@ fn main() {
eprintln!("{err}"); eprintln!("{err}");
return; return;
} }
}, }
// allow (and ignore) "from __future__ import annotations" // allow (and ignore) "from __future__ import annotations"
StmtKind::ImportFrom { module, names, .. } StmtKind::ImportFrom { module, names, .. }
if module == &Some("__future__".into()) && names.len() == 1 && names[0].name == "annotations".into() => (), if module == &Some("__future__".into())
&& names.len() == 1
&& names[0].name == "annotations".into() => {}
_ => { _ => {
let (name, def_id, ty) = let (name, def_id, ty) = composer
composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true).unwrap(); .register_top_level(stmt, Some(resolver.clone()), "__main__", true)
.unwrap();
internal_resolver.add_id_def(name, def_id); internal_resolver.add_id_def(name, def_id);
if let Some(ty) = ty { if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty); internal_resolver.add_id_type(name, ty);
@ -364,7 +351,8 @@ fn main() {
.unwrap_or_else(|_| panic!("cannot find run() entry point")) .unwrap_or_else(|_| panic!("cannot find run() entry point"))
.0] .0]
.write(); .write();
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance else { let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance
else {
unreachable!() unreachable!()
}; };
instance_to_symbol.insert(String::new(), "run".to_string()); instance_to_symbol.insert(String::new(), "run".to_string());
@ -444,7 +432,8 @@ fn main() {
function_iter = func.get_next_function(); function_iter = func.get_next_function();
} }
let target_machine = llvm_options.target let target_machine = llvm_options
.target
.create_target_machine(llvm_options.opt_level) .create_target_machine(llvm_options.opt_level)
.expect("couldn't create target machine"); .expect("couldn't create target machine");

View File

@ -1,3 +1,13 @@
#![deny(
future_incompatible,
let_underscore,
nonstandard_style,
rust_2024_compatibility,
clippy::all
)]
#![warn(clippy::pedantic)]
#![allow(clippy::semicolon_if_nothing_returned, clippy::uninlined_format_args)]
use std::env; use std::env;
static mut NOW: i64 = 0; static mut NOW: i64 = 0;
@ -29,17 +39,17 @@ pub extern "C" fn rtio_get_counter() -> i64 {
#[no_mangle] #[no_mangle]
pub extern "C" fn rtio_output(target: i32, data: i32) { pub extern "C" fn rtio_output(target: i32, data: i32) {
println!("rtio_output @{} target={:04x} data={}", unsafe { NOW }, target, data); println!("rtio_output @{} target={target:04x} data={data}", unsafe { NOW });
} }
#[no_mangle] #[no_mangle]
pub extern "C" fn print_int32(x: i32) { pub extern "C" fn print_int32(x: i32) {
println!("print_int32: {}", x); println!("print_int32: {x}");
} }
#[no_mangle] #[no_mangle]
pub extern "C" fn print_int64(x: i64) { pub extern "C" fn print_int64(x: i64) {
println!("print_int64: {}", x); println!("print_int64: {x}");
} }
#[no_mangle] #[no_mangle]
@ -47,12 +57,11 @@ pub extern "C" fn __nac3_personality(_state: u32, _exception_object: u32, _conte
unimplemented!(); unimplemented!();
} }
fn main() { fn main() {
let filename = env::args().nth(1).unwrap(); let filename = env::args().nth(1).unwrap();
unsafe { unsafe {
let lib = libloading::Library::new(filename).unwrap(); let lib = libloading::Library::new(filename).unwrap();
let func: libloading::Symbol<unsafe extern fn()> = lib.get(b"__modinit__").unwrap(); let func: libloading::Symbol<unsafe extern "C" fn()> = lib.get(b"__modinit__").unwrap();
func() func();
} }
} }