Add nac3/find_ndarray_anomalies.py

This commit is contained in:
lyken 2024-07-21 17:47:55 +08:00
parent c4b5957e3a
commit fdd770b498

View File

@ -0,0 +1,109 @@
#!/usr/bin/env python
# # PLACE ME IN nac3standalone/demo/find_ndarray_anomalies.py AND DO `chmod +x` ON ME!!
from __future__ import annotations
from pathlib import Path
from typing import *
import argparse
import re
import subprocess
class SourceTemplate(NamedTuple):
source: str
def _list_all_functions(self) -> list[str]:
# Find all `def ???()` functions
pattern = re.compile(
r"^def\s+(?P<name>.+?)\s*\(.*?\)\s*:\s*$",
re.MULTILINE
)
results = pattern.findall(self.source)
if not results:
raise RuntimeError("Something is wrong with the regex")
return results
def list_all_test_functions(self) -> Iterator[str]:
# Find all `def test_*()` functions
return [ n for n in self._list_all_functions() if n.startswith("test_") ]
def subst_function(self, func_name: str) -> str:
# Append a `def run()` at the end of `source`
return self.source + f"""
def run() -> int32:
{func_name}()
return 0
"""
@classmethod
def load(cls) -> SourceTemplate:
# What this function does:
# 1. Read src/ndarray.py
# 2. Chop off the `run()` function, leaving behind everything else.
# - This is done by finding the `def run() -> int32:` line and
# remove that line and everything below it.
original_source = Path("src/ndarray.py").read_text()
# Read the function's comment
pattern = re.compile(
r"(?P<source>.+?)^\s+^def\s+run\(\)\s*->\s*int32\s*:\s*$.*",
re.MULTILINE | re.DOTALL
)
regex_result = pattern.fullmatch(original_source)
if regex_result is None:
raise RuntimeError("Regex cannot find `def run() -> int32:`!")
# source is now original_source but without the
# `def run() -> int32:` line and everything below it.
source = regex_result.group("source")
return cls(source=source)
def proc_nac3(
py_path: Path,
size_t: int,
lli: bool,
nac3args: Optional[list[str]] = None
) -> subprocess.CompletedProcess:
# Prepare args
args = ["./run_demo.sh"]
if lli:
args.append("--lli")
args.extend(["--out", "/dev/stdout"])
if nac3args is not None:
args.extend(nac3args)
args.extend(["-s", str(size_t)])
args.append(str(py_path))
# Run
return subprocess.run(args, capture_output=True)
def main():
template = SourceTemplate.load()
tmp_py_path = Path("_ndarray_tmp.py")
for test_func in template.list_all_test_functions():
source = template.subst_function(test_func)
tmp_py_path.write_text(source)
# We only want to find LLVM codegen defects,
# therefore we don't need to run them, and `--lli` doesn't matter.
nac3result = proc_nac3(tmp_py_path, 32, lli=False, nac3args=["--debug"])
if nac3result.returncode != 0:
print(f"\n\n\n>>>>>>>>>>>>>> Error in {test_func}")
# utf8 decode content
# There should be no raw bytes output from nac3core when compiling
error = nac3result.stderr.decode()
print(error)
print("\n\n\nDone")
if __name__ == "__main__":
main()