core: Refactor PrimitiveDefinitionIds into an enum + refactor get_builtins() #408

Merged
sb10q merged 3 commits from issue-385-primdef into master 2024-06-12 15:44:33 +08:00
53 changed files with 6206 additions and 6932 deletions

21
Cargo.lock generated
View File

@ -625,6 +625,8 @@ dependencies = [
"parking_lot",
"rayon",
"regex",
"strum",
"strum_macros",
"test-case",
]
@ -1116,6 +1118,25 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29"
[[package]]
name = "strum_macros"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck 0.5.0",
"proc-macro2",
"quote",
"rustversion",
"syn 2.0.66",
]
[[package]]
name = "syn"
version = "1.0.109"

View File

@ -6,21 +6,20 @@ use nac3core::{
CodeGenContext, CodeGenerator,
},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, GenCall, helper::PRIMITIVE_DEF_IDS},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap}
toplevel::{helper::PrimDef, DefinitionId, GenCall},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, VarMap},
};
use nac3parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
use inkwell::{
context::Context,
module::Linkage,
types::IntType,
values::BasicValueEnum,
AddressSpace,
context::Context, module::Linkage, types::IntType, values::BasicValueEnum, AddressSpace,
};
use pyo3::{PyObject, PyResult, Python, types::{PyDict, PyList}};
use pyo3::{
types::{PyDict, PyList},
PyObject, PyResult, Python,
};
use crate::{symbol_resolver::InnerResolver, timeline::TimeFns};
@ -46,7 +45,7 @@ enum ParallelMode {
///
/// Each function call within the `with` block (except those within a nested `sequential` block)
/// are treated to be executed in parallel.
Deep
Deep,
}
pub struct ArtiqCodeGenerator<'a> {
@ -96,14 +95,13 @@ impl<'a> ArtiqCodeGenerator<'a> {
///
/// Direct-`parallel` block context refers to when the generator is generating statements whose
/// closest parent `with` statement is a `with parallel` block.
fn timeline_reset_start(
&mut self,
ctx: &mut CodeGenContext<'_, '_>
) -> Result<(), String> {
fn timeline_reset_start(&mut self, ctx: &mut CodeGenContext<'_, '_>) -> Result<(), String> {
if let Some(start) = self.start.clone() {
let start_val = self.gen_expr(ctx, &start)?
.unwrap()
.to_basic_value_enum(ctx, self, start.custom.unwrap())?;
let start_val = self.gen_expr(ctx, &start)?.unwrap().to_basic_value_enum(
ctx,
self,
start.custom.unwrap(),
)?;
self.timeline.emit_at_mu(ctx, start_val);
}
@ -129,20 +127,20 @@ impl<'a> ArtiqCodeGenerator<'a> {
store_name: Option<&str>,
) -> Result<(), String> {
if let Some(end) = end {
let old_end = self.gen_expr(ctx, &end)?
.unwrap()
.to_basic_value_enum(ctx, self, end.custom.unwrap())?;
let now = self.timeline.emit_now_mu(ctx);
let max = call_int_smax(
ctx,
old_end.into_int_value(),
now.into_int_value(),
Some("smax")
);
let end_store = self.gen_store_target(
let old_end = self.gen_expr(ctx, &end)?.unwrap().to_basic_value_enum(
ctx,
&end,
store_name.map(|name| format!("{name}.addr")).as_deref())?
self,
end.custom.unwrap(),
)?;
let now = self.timeline.emit_now_mu(ctx);
let max =
call_int_smax(ctx, old_end.into_int_value(), now.into_int_value(), Some("smax"));
let end_store = self
.gen_store_target(
ctx,
&end,
store_name.map(|name| format!("{name}.addr")).as_deref(),
)?
.unwrap();
ctx.builder.build_store(end_store, max).unwrap();
}
@ -164,11 +162,14 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
}
}
fn gen_block<'ctx, 'a, 'c, I: Iterator<Item=&'c Stmt<Option<Type>>>>(
fn gen_block<'ctx, 'a, 'c, I: Iterator<Item = &'c Stmt<Option<Type>>>>(
&mut self,
ctx: &mut CodeGenContext<'ctx, 'a>,
stmts: I
) -> Result<(), String> where Self: Sized {
stmts: I,
) -> Result<(), String>
where
Self: Sized,
{
// Legacy parallel emits timeline end-update/timeline-reset after each top-level statement
// in the parallel block
if self.parallel_mode == ParallelMode::Legacy {
@ -212,9 +213,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
let StmtKind::With { items, body, .. } = &stmt.node else {
unreachable!()
};
let StmtKind::With { items, body, .. } = &stmt.node else { unreachable!() };
if items.len() == 1 && items[0].optional_vars.is_none() {
let item = &items[0];
@ -239,9 +238,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
let old_parallel_mode = self.parallel_mode;
let now = if let Some(old_start) = &old_start {
self.gen_expr(ctx, old_start)?
.unwrap()
.to_basic_value_enum(ctx, self, old_start.custom.unwrap())?
self.gen_expr(ctx, old_start)?.unwrap().to_basic_value_enum(
ctx,
self,
old_start.custom.unwrap(),
)?
} else {
self.timeline.emit_now_mu(ctx)
};
@ -277,9 +278,7 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
node: ExprKind::Name { id: end, ctx: name_ctx.clone() },
custom: Some(ctx.primitives.int64),
};
let end = self
.gen_store_target(ctx, &end_expr, Some("end.addr"))?
.unwrap();
let end = self.gen_store_target(ctx, &end_expr, Some("end.addr"))?.unwrap();
ctx.builder.build_store(end, now).unwrap();
self.end = Some(end_expr);
self.name_counter += 1;
@ -309,10 +308,11 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
// set duration
let end_expr = self.end.take().unwrap();
let end_val = self
.gen_expr(ctx, &end_expr)?
.unwrap()
.to_basic_value_enum(ctx, self, end_expr.custom.unwrap())?;
let end_val = self.gen_expr(ctx, &end_expr)?.unwrap().to_basic_value_enum(
ctx,
self,
end_expr.custom.unwrap(),
)?;
// inside a sequential block
if old_start.is_none() {
@ -416,7 +416,7 @@ fn rpc_codegen_callback_fn<'ctx>(
let int32 = ctx.ctx.i32_type();
let tag_ptr_type = ctx.ctx.struct_type(&[ptr_type.into(), size_type.into()], false);
let service_id = int32.const_int(fun.1.0 as u64, false);
let service_id = int32.const_int(fun.1 .0 as u64, false);
// -- setup rpc tags
let mut tag = Vec::new();
if obj.is_some() {
@ -461,7 +461,8 @@ fn rpc_codegen_callback_fn<'ctx>(
let arg_length = args.len() + usize::from(obj.is_some());
let stackptr = call_stacksave(ctx, Some("rpc.stack"));
let args_ptr = ctx.builder
let args_ptr = ctx
.builder
.build_array_alloca(
ptr_type,
ctx.ctx.i32_type().const_int(arg_length as u64, false),
@ -477,10 +478,8 @@ fn rpc_codegen_callback_fn<'ctx>(
}
// default value handling
for k in keys {
mapping.insert(
k.name,
ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into()
);
mapping
.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into());
}
// reorder the parameters
let mut real_params = fun
@ -499,7 +498,8 @@ fn rpc_codegen_callback_fn<'ctx>(
}
for (i, arg) in real_params.iter().enumerate() {
let arg_slot = generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
let arg_slot =
generator.gen_var_alloc(ctx, arg.get_type(), Some(&format!("rpc.arg{i}"))).unwrap();
ctx.builder.build_store(arg_slot, *arg).unwrap();
let arg_slot = ctx.builder.build_bitcast(arg_slot, ptr_type, "rpc.arg").unwrap();
let arg_ptr = unsafe {
@ -508,7 +508,8 @@ fn rpc_codegen_callback_fn<'ctx>(
&[int32.const_int(i as u64, false)],
&format!("rpc.arg{i}"),
)
}.unwrap();
}
.unwrap();
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
}
@ -528,11 +529,7 @@ fn rpc_codegen_callback_fn<'ctx>(
)
});
ctx.builder
.build_call(
rpc_send,
&[service_id.into(), tag_ptr.into(), args_ptr.into()],
"rpc.send",
)
.build_call(rpc_send, &[service_id.into(), tag_ptr.into(), args_ptr.into()], "rpc.send")
.unwrap();
// reclaim stack space used by arguments
@ -575,13 +572,9 @@ fn rpc_codegen_callback_fn<'ctx>(
.build_call_or_invoke(rpc_recv, &[phi.as_basic_value()], "rpc.size.next")
.unwrap()
.into_int_value();
let is_done = ctx.builder
.build_int_compare(
inkwell::IntPredicate::EQ,
int32.const_zero(),
alloc_size,
"rpc.done",
)
let is_done = ctx
.builder
.build_int_compare(inkwell::IntPredicate::EQ, int32.const_zero(), alloc_size, "rpc.done")
.unwrap();
ctx.builder.build_conditional_branch(is_done, tail_bb, alloc_bb).unwrap();
@ -617,9 +610,15 @@ pub fn attributes_writeback(
let mut scratch_buffer = Vec::new();
for val in (*globals).values() {
let val = val.as_ref(py);
let ty = inner_resolver.get_obj_type(py, val, &mut ctx.unifier, &top_levels, &ctx.primitives)?;
let ty = inner_resolver.get_obj_type(
py,
val,
&mut ctx.unifier,
&top_levels,
&ctx.primitives,
)?;
if let Err(ty) = ty {
return Ok(Err(ty))
return Ok(Err(ty));
}
let ty = ty.unwrap();
match &*ctx.unifier.get_ty(ty) {
@ -632,14 +631,19 @@ pub fn attributes_writeback(
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
for (name, (field_ty, is_mutable)) in fields {
if !is_mutable {
continue
continue;
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string());
let index = ctx.get_attr_index(ty, *name);
values.push((*field_ty, ctx.build_gep_and_load(
obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)], None)));
values.push((
*field_ty,
ctx.build_gep_and_load(
obj.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
None,
),
));
}
}
if !attributes.is_empty() {
@ -648,33 +652,44 @@ pub fn attributes_writeback(
pydict.set_item("fields", attributes)?;
host_attributes.append(pydict)?;
}
},
}
TypeEnum::TList { ty: elem_ty } => {
if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() {
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
host_attributes.append(pydict)?;
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
values.push((
ty,
inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(),
));
}
},
}
_ => {}
}
}
let fun = FunSignature {
args: values.iter().enumerate().map(|(i, (ty, _))| FuncArg {
name: i.to_string().into(),
ty: *ty,
default_value: None
}).collect(),
args: values
.iter()
.enumerate()
.map(|(i, (ty, _))| FuncArg {
name: i.to_string().into(),
ty: *ty,
default_value: None,
})
.collect(),
ret: ctx.primitives.none,
vars: VarMap::default()
vars: VarMap::default(),
};
let args: Vec<_> = values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
if let Err(e) = rpc_codegen_callback_fn(ctx, None, (&fun, PRIMITIVE_DEF_IDS.int32), args, generator) {
let args: Vec<_> =
values.into_iter().map(|(_, val)| (None, ValueEnum::Dynamic(val))).collect();
if let Err(e) =
rpc_codegen_callback_fn(ctx, None, (&fun, PrimDef::Int32.id()), args, generator)
{
return Ok(Err(e));
}
Ok(Ok(()))
}).unwrap()?;
})
.unwrap()?;
Ok(())
}

View File

@ -14,16 +14,16 @@ use inkwell::{
OptimizationLevel,
};
use itertools::Itertools;
use nac3core::codegen::{CodeGenLLVMOptions, CodeGenTargetMachineOptions, gen_func_impl};
use nac3core::codegen::{gen_func_impl, CodeGenLLVMOptions, CodeGenTargetMachineOptions};
use nac3core::toplevel::builtins::get_exn_constructor;
use nac3core::typecheck::typedef::{TypeEnum, Unifier, VarMap};
use nac3parser::{
ast::{ExprKind, Stmt, StmtKind, StrRef},
parser::parse_program,
};
use pyo3::create_exception;
use pyo3::prelude::*;
use pyo3::{exceptions, types::PyBytes, types::PyDict, types::PySet};
use pyo3::create_exception;
use parking_lot::{Mutex, RwLock};
@ -46,7 +46,7 @@ use tempfile::{self, TempDir};
use crate::codegen::attributes_writeback;
use crate::{
codegen::{rpc_codegen_callback, ArtiqCodeGenerator},
symbol_resolver::{InnerResolver, PythonHelper, Resolver, DeferredEvaluationStore},
symbol_resolver::{DeferredEvaluationStore, InnerResolver, PythonHelper, Resolver},
};
mod codegen;
@ -138,9 +138,7 @@ impl Nac3 {
for mut stmt in parser_result {
let include = match stmt.node {
StmtKind::ClassDef {
ref decorator_list, ref mut body, ref mut bases, ..
} => {
StmtKind::ClassDef { ref decorator_list, ref mut body, ref mut bases, .. } => {
let nac3_class = decorator_list.iter().any(|decorator| {
if let ExprKind::Name { id, .. } = decorator.node {
id.to_string() == "nac3"
@ -160,7 +158,8 @@ impl Nac3 {
if *id == "Exception".into() {
Ok(true)
} else {
let base_obj = module.getattr(py, id.to_string().as_str())?;
let base_obj =
module.getattr(py, id.to_string().as_str())?;
let base_id = id_fn.call1((base_obj,))?.extract()?;
Ok(registered_class_ids.contains(&base_id))
}
@ -341,8 +340,9 @@ impl Nac3 {
let class_obj;
if let StmtKind::ClassDef { name, .. } = &stmt.node {
let class = py_module.getattr(name.to_string().as_str()).unwrap();
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap() &&
class.getattr("artiq_builtin").is_err() {
if issubclass.call1((class, exn_class)).unwrap().extract().unwrap()
&& class.getattr("artiq_builtin").is_err()
{
class_obj = Some(class);
} else {
class_obj = None;
@ -388,12 +388,12 @@ impl Nac3 {
let (name, def_id, ty) = composer
.register_top_level(stmt.clone(), Some(resolver.clone()), path, false)
.map_err(|e| {
CompileError::new_err(format!(
"compilation failed\n----------\n{e}"
))
CompileError::new_err(format!("compilation failed\n----------\n{e}"))
})?;
if let Some(class_obj) = class_obj {
self.exception_ids.write().insert(def_id.0, store_obj.call1(py, (class_obj, ))?.extract(py)?);
self.exception_ids
.write()
.insert(def_id.0, store_obj.call1(py, (class_obj,))?.extract(py)?);
}
match &stmt.node {
@ -470,7 +470,8 @@ impl Nac3 {
exception_ids: self.exception_ids.clone(),
deferred_eval_store: self.deferred_eval_store.clone(),
});
let resolver = Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let resolver =
Arc::new(Resolver(inner_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
let (_, def_id, _) = composer
.register_top_level(synthesized.pop().unwrap(), Some(resolver.clone()), "", false)
.unwrap();
@ -479,8 +480,12 @@ impl Nac3 {
FunSignature { args: vec![], ret: self.primitive.none, vars: VarMap::new() };
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature =
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache);
let signature = store.from_signature(
&mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature = store.add_cty(signature);
if let Err(e) = composer.start_analysis(true) {
@ -499,13 +504,11 @@ impl Nac3 {
msg.unwrap_or(e.iter().sorted().join("\n----------\n"))
)))
} else {
Err(CompileError::new_err(
format!(
"compilation failed\n----------\n{}",
e.iter().sorted().join("\n----------\n"),
),
))
}
Err(CompileError::new_err(format!(
"compilation failed\n----------\n{}",
e.iter().sorted().join("\n----------\n"),
)))
};
}
let top_level = Arc::new(composer.make_top_level_context());
@ -533,7 +536,9 @@ impl Nac3 {
py,
(
id.0.into_py(py),
class_def.getattr(py, name.to_string().as_str()).unwrap(),
class_def
.getattr(py, name.to_string().as_str())
.unwrap(),
),
)
.unwrap();
@ -548,7 +553,8 @@ impl Nac3 {
let defs = top_level.definitions.read();
let mut definition = defs[def_id.0].write();
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } =
&mut *definition else {
&mut *definition
else {
unreachable!()
};
@ -570,8 +576,12 @@ impl Nac3 {
let mut store = ConcreteTypeStore::new();
let mut cache = HashMap::new();
let signature =
store.from_signature(&mut composer.unifier, &self.primitive, &fun_signature, &mut cache);
let signature = store.from_signature(
&mut composer.unifier,
&self.primitive,
&fun_signature,
&mut cache,
);
let signature = store.add_cty(signature);
let attributes_writeback_task = CodeGenTask {
subst: Vec::default(),
@ -604,23 +614,28 @@ impl Nac3 {
let membuffer = membuffers.clone();
py.allow_threads(|| {
let (registry, handles) = WorkerRegistry::create_workers(
threads,
top_level.clone(),
&self.llvm_options,
&f
);
let (registry, handles) =
WorkerRegistry::create_workers(threads, top_level.clone(), &self.llvm_options, &f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
let mut generator = ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let mut generator =
ArtiqCodeGenerator::new("attributes_writeback".to_string(), size_t, self.time_fns);
let context = inkwell::context::Context::create();
let module = context.create_module("attributes_writeback");
let builder = context.create_builder();
let (_, module, _) = gen_func_impl(&context, &mut generator, &registry, builder, module,
attributes_writeback_task, |generator, ctx| {
let (_, module, _) = gen_func_impl(
&context,
&mut generator,
&registry,
builder,
module,
attributes_writeback_task,
|generator, ctx| {
attributes_writeback(ctx, generator, inner_resolver.as_ref(), &host_attributes)
}).unwrap();
},
)
.unwrap();
let buffer = module.write_bitcode_to_memory();
let buffer = buffer.as_slice().into();
membuffer.lock().push(buffer);
@ -636,11 +651,16 @@ impl Nac3 {
.create_module_from_ir(MemoryBuffer::create_from_memory_range(buffer, "main"))
.unwrap();
main.link_in_module(other)
.map_err(|err| CompileError::new_err(err.to_string()))?;
main.link_in_module(other).map_err(|err| CompileError::new_err(err.to_string()))?;
}
let builder = context.create_builder();
let modinit_return = main.get_function("__modinit__").unwrap().get_last_basic_block().unwrap().get_terminator().unwrap();
let modinit_return = main
.get_function("__modinit__")
.unwrap()
.get_last_basic_block()
.unwrap()
.get_terminator()
.unwrap();
builder.position_before(&modinit_return);
builder
.build_call(
@ -662,10 +682,7 @@ impl Nac3 {
}
// Demote all global variables that will not be referenced in the kernel to private
let preserved_symbols: Vec<&'static [u8]> = vec![
b"typeinfo",
b"now",
];
let preserved_symbols: Vec<&'static [u8]> = vec![b"typeinfo", b"now"];
let mut global_option = main.get_first_global();
while let Some(global) = global_option {
if !preserved_symbols.contains(&(global.get_name().to_bytes())) {
@ -674,7 +691,9 @@ impl Nac3 {
global_option = global.get_next_global();
}
let target_machine = self.llvm_options.target
let target_machine = self
.llvm_options
.target
.create_target_machine(self.llvm_options.opt_level)
.expect("couldn't create target machine");
@ -738,10 +757,7 @@ impl Nac3 {
}
}
fn link_with_lld(
elf_filename: String,
obj_filename: String,
) -> PyResult<()>{
fn link_with_lld(elf_filename: String, obj_filename: String) -> PyResult<()> {
let linker_args = vec![
"-shared".to_string(),
"--eh-frame-hdr".to_string(),
@ -760,9 +776,7 @@ fn link_with_lld(
return Err(CompileError::new_err("failed to start linker"));
}
} else {
return Err(CompileError::new_err(
"linker returned non-zero status code",
));
return Err(CompileError::new_err("linker returned non-zero status code"));
}
Ok(())
@ -772,7 +786,7 @@ fn add_exceptions(
composer: &mut TopLevelComposer,
builtin_def: &mut HashMap<StrRef, DefinitionId>,
builtin_ty: &mut HashMap<StrRef, Type>,
error_names: &[&str]
error_names: &[&str],
) -> Vec<Type> {
let mut types = Vec::new();
// note: this is only for builtin exceptions, i.e. the exception name is "0:{exn}"
@ -785,7 +799,7 @@ fn add_exceptions(
// constructor id
def_id + 1,
&mut composer.unifier,
&composer.primitives_ty
&composer.primitives_ty,
);
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_class)), None));
composer.definition_ast_list.push((Arc::new(RwLock::new(exception_fn)), None));
@ -834,7 +848,8 @@ impl Nac3 {
},
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_at_mu(ctx, arg);
Ok(None)
}))),
@ -852,7 +867,8 @@ impl Nac3 {
},
Arc::new(GenCall::new(Box::new(move |ctx, _, fun, args, generator| {
let arg_ty = fun.0.args[0].ty;
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
let arg =
args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty).unwrap();
time_fns.emit_delay_mu(ctx, arg);
Ok(None)
}))),
@ -867,8 +883,9 @@ impl Nac3 {
let types_mod = PyModule::import(py, "types").unwrap();
let get_id = |x: &PyAny| id_fn.call1((x,)).and_then(PyAny::extract).unwrap();
let get_attr_id = |obj: &PyModule, attr| id_fn.call1((obj.getattr(attr).unwrap(),))
.unwrap().extract().unwrap();
let get_attr_id = |obj: &PyModule, attr| {
id_fn.call1((obj.getattr(attr).unwrap(),)).unwrap().extract().unwrap()
};
let primitive_ids = PrimitivePythonId {
virtual_id: get_id(artiq_builtins.get_item("virtual").ok().flatten().unwrap()),
generic_alias: (
@ -877,7 +894,9 @@ impl Nac3 {
),
none: get_id(artiq_builtins.get_item("none").ok().flatten().unwrap()),
typevar: get_attr_id(typing_mod, "TypeVar"),
const_generic_marker: get_id(artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap()),
const_generic_marker: get_id(
artiq_builtins.get_item("_ConstGenericMarker").ok().flatten().unwrap(),
),
int: get_attr_id(builtins_mod, "int"),
int32: get_attr_id(numpy_mod, "int32"),
int64: get_attr_id(numpy_mod, "int64"),
@ -911,7 +930,7 @@ impl Nac3 {
llvm_options: CodeGenLLVMOptions {
opt_level: OptimizationLevel::Default,
target: Nac3::get_llvm_target_options(isa),
}
},
})
}
@ -952,7 +971,7 @@ impl Nac3 {
py: Python,
) -> PyResult<()> {
let target_machine = self.get_llvm_target_machine();
if self.isa == Isa::Host {
let link_fn = |module: &Module| {
let working_directory = self.working_directory.path().to_owned();
@ -961,7 +980,7 @@ impl Nac3 {
.expect("couldn't write module to file");
link_with_lld(
filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string()
working_directory.join("module.o").to_string_lossy().to_string(),
)?;
Ok(())
};
@ -997,7 +1016,7 @@ impl Nac3 {
py: Python,
) -> PyResult<PyObject> {
let target_machine = self.get_llvm_target_machine();
if self.isa == Isa::Host {
let link_fn = |module: &Module| {
let working_directory = self.working_directory.path().to_owned();
@ -1009,7 +1028,7 @@ impl Nac3 {
let filename = filename_path.to_str().unwrap();
link_with_lld(
filename.to_string(),
working_directory.join("module.o").to_string_lossy().to_string()
working_directory.join("module.o").to_string_lossy().to_string(),
)?;
Ok(PyBytes::new(py, &fs::read(filename).unwrap()).into())

View File

@ -3,10 +3,9 @@ use nac3core::{
codegen::{CodeGenContext, CodeGenerator},
symbol_resolver::{StaticValue, SymbolResolver, SymbolValue, ValueEnum},
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
helper::PrimDef,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelDef,
DefinitionId, TopLevelDef,
},
typecheck::{
type_inferencer::PrimitiveStore,
@ -22,9 +21,9 @@ use pyo3::{
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering::Relaxed},
Arc,
atomic::{AtomicBool, Ordering::Relaxed}
}
},
};
use crate::PrimitivePythonId;
@ -58,7 +57,7 @@ impl DeferredEvaluationStore {
}
}
/// A class field as stored in the [`InnerResolver`], represented by the ID and name of the
/// A class field as stored in the [`InnerResolver`], represented by the ID and name of the
/// associated [`PythonValue`].
type ResolverField = (u64, StrRef);
/// A class field as stored in Python, represented by the `id()` and [`PyObject`] of the field.
@ -114,27 +113,27 @@ impl StaticValue for PythonValue {
ctx: &mut CodeGenContext<'ctx, '_>,
_: &mut dyn CodeGenerator,
) -> BasicValueEnum<'ctx> {
ctx.module
.get_global(format!("{}_const", self.id).as_str())
.map_or_else(
|| Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?;
let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false);
let global = ctx.module.add_global(
struct_type,
None,
format!("{}_const", self.id).as_str(),
);
global.set_constant(true);
global.set_initializer(&ctx.ctx.const_struct(
&[ctx.ctx.i32_type().const_int(id as u64, false).into()],
false,
));
Ok(global.as_pointer_value().into())
})
.unwrap(),
|val| val.as_pointer_value().into(),
)
ctx.module.get_global(format!("{}_const", self.id).as_str()).map_or_else(
|| {
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
let id: u32 = self.store_obj.call1(py, (self.value.clone(),))?.extract(py)?;
let struct_type = ctx.ctx.struct_type(&[ctx.ctx.i32_type().into()], false);
let global = ctx.module.add_global(
struct_type,
None,
format!("{}_const", self.id).as_str(),
);
global.set_constant(true);
global.set_initializer(&ctx.ctx.const_struct(
&[ctx.ctx.i32_type().const_int(id as u64, false).into()],
false,
));
Ok(global.as_pointer_value().into())
})
.unwrap()
},
|val| val.as_pointer_value().into(),
)
}
fn to_basic_value_enum<'ctx, 'a>(
@ -161,7 +160,8 @@ impl StaticValue for PythonValue {
self.resolver
.get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty)
.map(Option::unwrap)
}).map_err(|e| e.to_string())
})
.map_err(|e| e.to_string())
}
fn get_field<'ctx>(
@ -186,7 +186,7 @@ impl StaticValue for PythonValue {
Ok(None)
} else {
Ok(Some((id, obj)))
}
};
}
let def_id = { *self.resolver.pyid_to_def.read().get(&ty_id).unwrap() };
let mut mutable = true;
@ -264,9 +264,7 @@ impl InnerResolver {
.map(|elem| self.get_obj_type(py, elem, unifier, defs, primitives))??
{
Ok(t) => t,
Err(e) => {
return Ok(Err(format!("type error ({e}) at element #{i} of the list")))
}
Err(e) => return Ok(Err(format!("type error ({e}) at element #{i} of the list"))),
};
ty = match unifier.unify(ty, b) {
Ok(()) => ty,
@ -377,7 +375,7 @@ impl InnerResolver {
let constr_id: u64 = self.helper.id_fn.call1(py, (constr,))?.extract(py)?;
if constr_id == self.primitive_ids.const_generic_marker {
is_const_generic = true;
continue
continue;
}
if !is_const_generic && needs_defer {
@ -406,11 +404,11 @@ impl InnerResolver {
}
if !is_const_generic && needs_defer {
self.deferred_eval_store.store.write()
.push((result.clone(),
constraints.extract()?,
pyty.getattr("__name__")?.extract::<String>()?
));
self.deferred_eval_store.store.write().push((
result.clone(),
constraints.extract()?,
pyty.getattr("__name__")?.extract::<String>()?,
));
}
(result, is_const_generic)
@ -418,7 +416,10 @@ impl InnerResolver {
let res = if is_const_generic {
if constraint_types.len() != 1 {
return Ok(Err(format!("ConstGeneric expects 1 argument, got {}", constraint_types.len())))
return Ok(Err(format!(
"ConstGeneric expects 1 argument, got {}",
constraint_types.len()
)));
}
unifier.get_fresh_const_generic_var(constraint_types[0], Some(name.into()), None).0
@ -468,7 +469,7 @@ impl InnerResolver {
)));
}
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
if args.len() != 2 {
return Ok(Err(format!(
"type list needs exactly 2 type parameters, found {}",
@ -572,9 +573,7 @@ impl InnerResolver {
let str_fn =
pyo3::types::PyModule::import(py, "builtins").unwrap().getattr("repr").unwrap();
let str_repr: String = str_fn.call1((pyty,)).unwrap().extract().unwrap();
Ok(Err(format!(
"{str_repr} is not registered with NAC3 (@nac3 decorator missing?)"
)))
Ok(Err(format!("{str_repr} is not registered with NAC3 (@nac3 decorator missing?)")))
}
}
@ -589,31 +588,28 @@ impl InnerResolver {
let ty = self.helper.type_fn.call1(py, (obj,)).unwrap();
let py_obj_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
if let Some(ty) = self.pyid_to_type.read().get(&py_obj_id) {
return Ok(Ok(*ty))
return Ok(Ok(*ty));
}
// check if constructor function exists in the methods list
let pyid_to_def = self.pyid_to_def.read();
let constructor_ty = pyid_to_def
.get(&py_obj_id)
.and_then(|def_id| {
defs
.iter()
.find_map(|def| {
if let TopLevelDef::Class {
object_id, methods, constructor, ..
} = &*def.read() {
if object_id == def_id && constructor.is_some() && methods.iter().any(|(s, _, _)| s == &"__init__".into()) {
return *constructor;
}
let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| {
defs.iter().find_map(|def| {
if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() {
if object_id == def_id
&& constructor.is_some()
&& methods.iter().any(|(s, _, _)| s == &"__init__".into())
{
return *constructor;
}
None
})
});
}
None
})
});
if let Some(ty) = constructor_ty {
self.pyid_to_type.write().insert(py_obj_id, ty);
return Ok(Ok(ty))
return Ok(Ok(ty));
}
let (extracted_ty, inst_check) = match self.get_pyty_obj_type(
@ -664,7 +660,7 @@ impl InnerResolver {
}
}
}
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PrimDef::NDArray.id() => {
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
if len == 0 {
@ -680,12 +676,8 @@ impl InnerResolver {
match actual_ty {
Ok(t) => match unifier.unify(ty, t) {
Ok(()) => {
let ndarray_ty = make_ndarray_ty(
unifier,
primitives,
Some(ty),
Some(ndims),
);
let ndarray_ty =
make_ndarray_ty(unifier, primitives, Some(ty), Some(ndims));
Ok(Ok(ndarray_ty))
}
@ -726,7 +718,8 @@ impl InnerResolver {
let var_map = params
.iter()
.map(|(id_var, ty)| {
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) else {
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty)
else {
unreachable!()
};
@ -734,7 +727,7 @@ impl InnerResolver {
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
})
.collect::<VarMap>();
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()))
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()));
}
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
@ -754,8 +747,8 @@ impl InnerResolver {
let var_map = params
.iter()
.map(|(id_var, ty)| {
let TypeEnum::TVar { id, range, name, loc, .. } =
&*unifier.get_ty(*ty) else {
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty)
else {
unreachable!()
};
@ -767,25 +760,23 @@ impl InnerResolver {
// loop through non-function fields of the class to get the instantiated value
for field in fields {
let name: String = (*field.0).into();
if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1.0) {
if let TypeEnum::TFunc(..) = &*unifier.get_ty(field.1 .0) {
continue;
}
let field_data = match obj.getattr(name.as_str()) {
Ok(d) => d,
Err(e) => return Ok(Err(format!("{e}"))),
};
let ty = match self
.get_obj_type(py, field_data, unifier, defs, primitives)?
{
Ok(t) => t,
Err(e) => {
return Ok(Err(format!(
"error when getting type of field `{name}` ({e})"
)))
}
};
let field_ty =
unifier.subst(field.1.0, &var_map).unwrap_or(field.1.0);
let ty =
match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
Ok(t) => t,
Err(e) => {
return Ok(Err(format!(
"error when getting type of field `{name}` ({e})"
)))
}
};
let field_ty = unifier.subst(field.1 .0, &var_map).unwrap_or(field.1 .0);
if let Err(e) = unifier.unify(ty, field_ty) {
// field type mismatch
return Ok(Err(format!(
@ -800,14 +791,15 @@ impl InnerResolver {
return Ok(Err("object is not of concrete type".into()));
}
}
let extracted_ty = unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty);
let extracted_ty =
unifier.subst(extracted_ty, &var_map).unwrap_or(extracted_ty);
Ok(Ok(extracted_ty))
};
let result = instantiate_obj();
// update/remove the cache according to the result
match result {
Ok(Ok(ty)) => self.pyid_to_type.write().insert(py_obj_id, ty),
_ => self.pyid_to_type.write().remove(&py_obj_id)
_ => self.pyid_to_type.write().remove(&py_obj_id),
};
result
}
@ -816,32 +808,32 @@ impl InnerResolver {
if unifier.unioned(extracted_ty, primitives.int32) {
obj.extract::<i32>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of int32"))),
|_| Ok(Ok(extracted_ty))
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.int64) {
obj.extract::<i64>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of int64"))),
|_| Ok(Ok(extracted_ty))
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.uint32) {
obj.extract::<u32>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of uint32"))),
|_| Ok(Ok(extracted_ty))
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.uint64) {
obj.extract::<u64>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of uint64"))),
|_| Ok(Ok(extracted_ty))
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.bool) {
obj.extract::<bool>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of bool"))),
|_| Ok(Ok(extracted_ty))
|_| Ok(Ok(extracted_ty)),
)
} else if unifier.unioned(extracted_ty, primitives.float) {
obj.extract::<f64>().map_or_else(
|_| Ok(Err(format!("{obj} is not in the range of float64"))),
|_| Ok(Ok(extracted_ty))
|_| Ok(Ok(extracted_ty)),
)
} else {
Ok(Ok(extracted_ty))
@ -893,8 +885,8 @@ impl InnerResolver {
}
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
let elem_ty =
if let TypeEnum::TList { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref()
let elem_ty = if let TypeEnum::TList { ty } =
ctx.unifier.get_ty_immutable(expected_ty).as_ref()
{
*ty
} else {
@ -918,13 +910,11 @@ impl InnerResolver {
let arr: Result<Option<Vec<_>>, _> = (0..len)
.map(|i| {
obj
.get_item(i)
.and_then(|elem| self.get_obj_value(py, elem, ctx, generator, elem_ty)
.map_err(
|e| super::CompileError::new_err(
format!("Error getting element {i}: {e}"))
))
obj.get_item(i).and_then(|elem| {
self.get_obj_value(py, elem, ctx, generator, elem_ty).map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})
})
})
.collect();
let arr = arr?.unwrap();
@ -956,7 +946,10 @@ impl InnerResolver {
arr_global.set_initializer(&arr);
let val = arr_ty.const_named_struct(&[
arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::default())).into(),
arr_global
.as_pointer_value()
.const_cast(ty.ptr_type(AddressSpace::default()))
.into(),
size_t.const_int(len as u64, false).into(),
]);
@ -968,25 +961,21 @@ impl InnerResolver {
todo!()
} else if ty_id == self.primitive_ids.tuple {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else {
unreachable!()
};
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { unreachable!() };
let tup_tys = ty.iter();
let elements: &PyTuple = obj.downcast()?;
assert_eq!(elements.len(), tup_tys.len());
let val: Result<Option<Vec<_>>, _> =
elements
.iter()
.enumerate()
.zip(tup_tys)
.map(|((i, elem), ty)| self
.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e|
super::CompileError::new_err(
format!("Error getting element {i}: {e}")
)
)
).collect();
let val: Result<Option<Vec<_>>, _> = elements
.iter()
.enumerate()
.zip(tup_tys)
.map(|((i, elem), ty)| {
self.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| {
super::CompileError::new_err(format!("Error getting element {i}: {e}"))
})
})
.collect();
let val = val?.unwrap();
let val = ctx.ctx.const_struct(&val, false);
Ok(Some(val.into()))
@ -997,7 +986,7 @@ impl InnerResolver {
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type")
_ => unreachable!("must be option type"),
};
if id == self.primitive_ids.none {
// for option type, just a null ptr
@ -1009,7 +998,13 @@ impl InnerResolver {
))
} else {
match self
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator, option_val_ty)
.get_obj_value(
py,
obj.getattr("_nac3_option").unwrap(),
ctx,
generator,
option_val_ty,
)
.map_err(|e| {
super::CompileError::new_err(format!(
"Error getting value of Option object: {e}"
@ -1019,17 +1014,26 @@ impl InnerResolver {
let global_str = format!("{id}_option");
{
if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&global_str).unwrap_or_else(|| {
ctx.module.add_global(v.get_type(), Some(AddressSpace::default()), &global_str)
});
let global =
ctx.module.get_global(&global_str).unwrap_or_else(|| {
ctx.module.add_global(
v.get_type(),
Some(AddressSpace::default()),
&global_str,
)
});
return Ok(Some(global.as_pointer_value().into()));
}
self.global_value_ids.write().insert(id, obj.into());
}
let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::default()), &global_str);
let global = ctx.module.add_global(
v.get_type(),
Some(AddressSpace::default()),
&global_str,
);
global.set_initializer(&v);
Ok(Some(global.as_pointer_value().into()))
},
}
None => Ok(None),
}
}
@ -1066,8 +1070,16 @@ impl InnerResolver {
let values: Result<Option<Vec<_>>, _> = fields
.iter()
.map(|(name, ty, _)| {
self.get_obj_value(py, obj.getattr(name.to_string().as_str())?, ctx, generator, *ty)
.map_err(|e| super::CompileError::new_err(format!("Error getting field {name}: {e}")))
self.get_obj_value(
py,
obj.getattr(name.to_string().as_str())?,
ctx,
generator,
*ty,
)
.map_err(|e| {
super::CompileError::new_err(format!("Error getting field {name}: {e}"))
})
})
.collect();
let values = values?;
@ -1119,8 +1131,7 @@ impl InnerResolver {
if id == self.primitive_ids.none {
Ok(SymbolValue::OptionNone)
} else {
self
.get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())?
self.get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())?
.map(|v| SymbolValue::OptionSome(Box::new(v)))
}
} else {
@ -1149,7 +1160,8 @@ impl SymbolResolver for Resolver {
}
}
Ok(sym_value)
}).unwrap()
})
.unwrap()
}
fn get_symbol_type(
@ -1166,7 +1178,7 @@ impl SymbolResolver for Resolver {
Ok(ty)
} else {
let Some(id) = self.0.name_to_pyid.get(&str) else {
return Err(format!("cannot find symbol `{str}`"))
return Err(format!("cannot find symbol `{str}`"));
};
let result = if let Some(t) = {
let pyid_to_type = self.0.pyid_to_type.read();
@ -1191,7 +1203,8 @@ impl SymbolResolver for Resolver {
}
}
Ok(sym_ty)
}).unwrap()
})
.unwrap()
};
result
}
@ -1242,15 +1255,16 @@ impl SymbolResolver for Resolver {
id_to_def.get(&id).copied().ok_or_else(String::new)
}
.or_else(|_| {
let py_id = self.0.name_to_pyid.get(&id)
.ok_or_else(|| HashSet::from([
format!("Undefined identifier `{id}`"),
]))?;
let result = self.0.pyid_to_def.read().get(py_id)
.copied()
.ok_or_else(|| HashSet::from([
format!("`{id}` is not registered with NAC3 (@nac3 decorator missing?)"),
]))?;
let py_id = self
.0
.name_to_pyid
.get(&id)
.ok_or_else(|| HashSet::from([format!("Undefined identifier `{id}`")]))?;
let result = self.0.pyid_to_def.read().get(py_id).copied().ok_or_else(|| {
HashSet::from([format!(
"`{id}` is not registered with NAC3 (@nac3 decorator missing?)"
)])
})?;
self.0.id_to_def.write().insert(id, result);
Ok(result)
})
@ -1274,7 +1288,7 @@ impl SymbolResolver for Resolver {
&self,
unifier: &mut Unifier,
defs: &[Arc<RwLock<TopLevelDef>>],
primitives: &PrimitiveStore
primitives: &PrimitiveStore,
) -> Result<(), String> {
// we don't need a lock because this will only be run in a single thread
if self.0.deferred_eval_store.needs_defer.load(Relaxed) {
@ -1304,7 +1318,8 @@ impl SymbolResolver for Resolver {
}
}
Ok(Ok(()))
}).unwrap()?;
})
.unwrap()?;
}
Ok(())
}

View File

@ -1,10 +1,12 @@
use inkwell::{values::{BasicValueEnum, CallSiteValue}, AddressSpace, AtomicOrdering};
use inkwell::{
values::{BasicValueEnum, CallSiteValue},
AddressSpace, AtomicOrdering,
};
use itertools::Either;
use nac3core::codegen::CodeGenContext;
/// Functions for manipulating the timeline.
pub trait TimeFns {
/// Emits LLVM IR for `now_mu`.
fn emit_now_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> BasicValueEnum<'ctx>;
@ -27,26 +29,31 @@ impl TimeFns for NowPinningTimeFns64 {
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder
let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}.unwrap();
}
.unwrap();
let now_hi = ctx.builder.build_load(now_hiptr, "now.hi")
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx.builder.build_load(now_loptr, "now.lo")
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder
.build_left_shift(zext_hi, i64_type.const_int(32, false), "")
.unwrap();
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").map(Into::into).unwrap()
}
@ -58,7 +65,8 @@ impl TimeFns for NowPinningTimeFns64 {
let i64_32 = i64_type.const_int(32, false);
let time = t.into_int_value();
let time_hi = ctx.builder
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type,
@ -70,14 +78,16 @@ impl TimeFns for NowPinningTimeFns64 {
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder
let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}.unwrap();
}
.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
@ -90,50 +100,49 @@ impl TimeFns for NowPinningTimeFns64 {
.unwrap();
}
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
let i64_type = ctx.ctx.i64_type();
let i32_type = ctx.ctx.i32_type();
let now = ctx
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder
let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
}.unwrap();
}
.unwrap();
let now_hi = ctx.builder.build_load(now_hiptr, "now.hi")
let now_hi = ctx
.builder
.build_load(now_hiptr, "now.hi")
.map(BasicValueEnum::into_int_value)
.unwrap();
let now_lo = ctx.builder.build_load(now_loptr, "now.lo")
let now_lo = ctx
.builder
.build_load(now_loptr, "now.lo")
.map(BasicValueEnum::into_int_value)
.unwrap();
let dt = dt.into_int_value();
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "").unwrap();
let shifted_hi = ctx.builder
.build_left_shift(zext_hi, i64_type.const_int(32, false), "")
.unwrap();
let shifted_hi =
ctx.builder.build_left_shift(zext_hi, i64_type.const_int(32, false), "").unwrap();
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "").unwrap();
let now_val = ctx.builder.build_or(shifted_hi, zext_lo, "now").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(
time,
i64_type.const_int(32, false),
false,
"",
).unwrap(),
ctx.builder
.build_right_shift(time, i64_type.const_int(32, false), false, "")
.unwrap(),
i32_type,
"time.hi",
)
@ -164,16 +173,16 @@ impl TimeFns for NowPinningTimeFns {
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now")
let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "now")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo").unwrap();
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
ctx.builder.build_or(now_lo, now_hi, "now_mu")
.map(Into::into)
.unwrap()
ctx.builder.build_or(now_lo, now_hi, "now_mu").map(Into::into).unwrap()
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -183,7 +192,8 @@ impl TimeFns for NowPinningTimeFns {
let time = t.into_int_value();
let time_hi = ctx.builder
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "").unwrap(),
i32_type,
@ -195,14 +205,16 @@ impl TimeFns for NowPinningTimeFns {
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_hiptr = ctx.builder
let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}.unwrap();
}
.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
@ -215,11 +227,7 @@ impl TimeFns for NowPinningTimeFns {
.unwrap();
}
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
let i32_type = ctx.ctx.i32_type();
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
@ -227,7 +235,8 @@ impl TimeFns for NowPinningTimeFns {
.module
.get_global("now")
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder
let now_raw = ctx
.builder
.build_load(now.as_pointer_value(), "")
.map(BasicValueEnum::into_int_value)
.unwrap();
@ -238,7 +247,8 @@ impl TimeFns for NowPinningTimeFns {
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi").unwrap();
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val").unwrap();
let time = ctx.builder.build_int_add(now_val, dt, "time").unwrap();
let time_hi = ctx.builder
let time_hi = ctx
.builder
.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi").unwrap(),
i32_type,
@ -246,14 +256,16 @@ impl TimeFns for NowPinningTimeFns {
)
.unwrap();
let time_lo = ctx.builder.build_int_truncate(time, i32_type, "time.lo").unwrap();
let now_hiptr = ctx.builder
let now_hiptr = ctx
.builder
.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr")
.map(BasicValueEnum::into_pointer_value)
.unwrap();
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
}.unwrap();
}
.unwrap();
ctx.builder
.build_store(now_hiptr, time_hi)
.unwrap()
@ -276,7 +288,8 @@ impl TimeFns for ExternTimeFns {
let now_mu = ctx.module.get_function("now_mu").unwrap_or_else(|| {
ctx.module.add_function("now_mu", ctx.ctx.i64_type().fn_type(&[], false), None)
});
ctx.builder.build_call(now_mu, &[], "now_mu")
ctx.builder
.build_call(now_mu, &[], "now_mu")
.map(CallSiteValue::try_as_basic_value)
.map(Either::unwrap_left)
.unwrap()
@ -293,11 +306,7 @@ impl TimeFns for ExternTimeFns {
ctx.builder.build_call(at_mu, &[t.into()], "at_mu").unwrap();
}
fn emit_delay_mu<'ctx>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
dt: BasicValueEnum<'ctx>,
) {
fn emit_delay_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, dt: BasicValueEnum<'ctx>) {
let delay_mu = ctx.module.get_function("delay_mu").unwrap_or_else(|| {
ctx.module.add_function(
"delay_mu",

File diff suppressed because it is too large Load Diff

View File

@ -85,33 +85,22 @@ impl<U> crate::fold::Fold<U> for ConstantOptimizer {
fn fold_expr(&mut self, node: crate::Expr<U>) -> Result<crate::Expr<U>, Self::Error> {
match node.node {
crate::ExprKind::Tuple { elts, ctx } => {
let elts = elts
.into_iter()
.map(|x| self.fold_expr(x))
.collect::<Result<Vec<_>, _>>()?;
let expr = if elts
.iter()
.all(|e| matches!(e.node, crate::ExprKind::Constant { .. }))
{
let tuple = elts
.into_iter()
.map(|e| match e.node {
crate::ExprKind::Constant { value, .. } => value,
_ => unreachable!(),
})
.collect();
crate::ExprKind::Constant {
value: Constant::Tuple(tuple),
kind: None,
}
} else {
crate::ExprKind::Tuple { elts, ctx }
};
Ok(crate::Expr {
node: expr,
custom: node.custom,
location: node.location,
})
let elts =
elts.into_iter().map(|x| self.fold_expr(x)).collect::<Result<Vec<_>, _>>()?;
let expr =
if elts.iter().all(|e| matches!(e.node, crate::ExprKind::Constant { .. })) {
let tuple = elts
.into_iter()
.map(|e| match e.node {
crate::ExprKind::Constant { value, .. } => value,
_ => unreachable!(),
})
.collect();
crate::ExprKind::Constant { value: Constant::Tuple(tuple), kind: None }
} else {
crate::ExprKind::Tuple { elts, ctx }
};
Ok(crate::Expr { node: expr, custom: node.custom, location: node.location })
}
_ => crate::fold::fold_expr(self, node),
}
@ -138,18 +127,12 @@ mod tests {
Located {
location,
custom,
node: ExprKind::Constant {
value: 1.into(),
kind: None,
},
node: ExprKind::Constant { value: 1.into(), kind: None },
},
Located {
location,
custom,
node: ExprKind::Constant {
value: 2.into(),
kind: None,
},
node: ExprKind::Constant { value: 2.into(), kind: None },
},
Located {
location,
@ -160,26 +143,17 @@ mod tests {
Located {
location,
custom,
node: ExprKind::Constant {
value: 3.into(),
kind: None,
},
node: ExprKind::Constant { value: 3.into(), kind: None },
},
Located {
location,
custom,
node: ExprKind::Constant {
value: 4.into(),
kind: None,
},
node: ExprKind::Constant { value: 4.into(), kind: None },
},
Located {
location,
custom,
node: ExprKind::Constant {
value: 5.into(),
kind: None,
},
node: ExprKind::Constant { value: 5.into(), kind: None },
},
],
},
@ -187,9 +161,7 @@ mod tests {
],
},
};
let new_ast = ConstantOptimizer::new()
.fold_expr(ast)
.unwrap_or_else(|e| match e {});
let new_ast = ConstantOptimizer::new().fold_expr(ast).unwrap_or_else(|e| match e {});
assert_eq!(
new_ast,
Located {
@ -199,11 +171,7 @@ mod tests {
value: Constant::Tuple(vec![
1.into(),
2.into(),
Constant::Tuple(vec![
3.into(),
4.into(),
5.into(),
])
Constant::Tuple(vec![3.into(), 4.into(), 5.into(),])
]),
kind: None
},

View File

@ -64,11 +64,4 @@ macro_rules! simple_fold {
};
}
simple_fold!(
usize,
String,
bool,
StrRef,
constant::Constant,
constant::ConversionFlag
);
simple_fold!(usize, String, bool, StrRef, constant::Constant, constant::ConversionFlag);

View File

@ -34,10 +34,7 @@ impl<U> ExprKind<U> {
ExprKind::Starred { .. } => "starred",
ExprKind::Slice { .. } => "slice",
ExprKind::JoinedStr { values } => {
if values
.iter()
.any(|e| matches!(e.node, ExprKind::JoinedStr { .. }))
{
if values.iter().any(|e| matches!(e.node, ExprKind::JoinedStr { .. })) {
"f-string expression"
} else {
"literal"

View File

@ -9,6 +9,6 @@ mod impls;
mod location;
pub use ast_gen::*;
pub use location::{Location, FileName};
pub use location::{FileName, Location};
pub type Suite<U = ()> = Vec<Stmt<U>>;

View File

@ -1,6 +1,6 @@
//! Datatypes to support source location information.
use std::cmp::Ordering;
use crate::ast_gen::StrRef;
use std::cmp::Ordering;
use std::fmt;
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
@ -22,7 +22,7 @@ impl From<String> for FileName {
pub struct Location {
pub row: usize,
pub column: usize,
pub file: FileName
pub file: FileName,
}
impl fmt::Display for Location {
@ -35,12 +35,12 @@ impl Ord for Location {
fn cmp(&self, other: &Self) -> Ordering {
let file_cmp = self.file.0.to_string().cmp(&other.file.0.to_string());
if file_cmp != Ordering::Equal {
return file_cmp
return file_cmp;
}
let row_cmp = self.row.cmp(&other.row);
if row_cmp != Ordering::Equal {
return row_cmp
return row_cmp;
}
self.column.cmp(&other.column)
@ -76,11 +76,7 @@ impl Location {
)
}
}
Visualize {
loc: *self,
line,
desc,
}
Visualize { loc: *self, line, desc }
}
}

View File

@ -11,6 +11,8 @@ indexmap = "2.2"
parking_lot = "0.12"
rayon = "1.8"
nac3parser = { path = "../nac3parser" }
strum = "0.26.2"
strum_macros = "0.26.4"
[dependencies.inkwell]
version = "0.4"

File diff suppressed because it is too large Load Diff

View File

@ -1,29 +1,28 @@
use inkwell::{
AddressSpace, IntPredicate,
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue},
use crate::codegen::{
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
CodeGenContext, CodeGenerator,
};
use inkwell::context::Context;
use inkwell::types::{ArrayType, BasicType, StructType};
use inkwell::values::{ArrayValue, BasicValue, StructValue};
use crate::codegen::{
CodeGenContext,
CodeGenerator,
irrt::{call_ndarray_calc_size, call_ndarray_flatten_index},
llvm_intrinsics::call_int_umin,
stmt::gen_for_callback_incrementing,
use inkwell::{
types::{AnyTypeEnum, BasicTypeEnum, IntType, PointerType},
values::{BasicValueEnum, IntValue, PointerValue},
AddressSpace, IntPredicate,
};
/// A LLVM type that is used to represent a non-primitive type in NAC3.
pub trait ProxyType<'ctx>: Into<Self::Base> {
/// The LLVM type of which values of this type possess. This is usually a
/// The LLVM type of which values of this type possess. This is usually a
/// [LLVM pointer type][PointerType].
type Base: BasicType<'ctx>;
/// The underlying LLVM type used to represent values. This is usually the element type of
/// The underlying LLVM type used to represent values. This is usually the element type of
/// [`Base`] if it is a pointer, otherwise this is the same type as `Base`.
type Underlying: BasicType<'ctx>;
/// The type of values represented by this type.
type Value: ProxyValue<'ctx>;
@ -64,7 +63,7 @@ pub trait ProxyType<'ctx>: Into<Self::Base> {
/// A LLVM type that is used to represent a non-primitive value in NAC3.
pub trait ProxyValue<'ctx>: Into<Self::Base> {
/// The type of LLVM values represented by this instance. This is usually the
/// The type of LLVM values represented by this instance. This is usually the
/// [LLVM pointer type][PointerValue].
type Base: BasicValue<'ctx>;
@ -81,7 +80,7 @@ pub trait ProxyValue<'ctx>: Into<Self::Base> {
/// Returns the [base value][Self::Base] of this proxy.
fn as_base_value(&self) -> Self::Base;
/// Loads this value into its [underlying representation][Self::Underlying]. Usually involves a
/// Loads this value into its [underlying representation][Self::Underlying]. Usually involves a
/// `getelementptr` if [`Self::Base`] is a [pointer value][PointerValue].
fn as_underlying_value(
&self,
@ -152,7 +151,9 @@ pub trait ArrayLikeIndexer<'ctx, Index = IntValue<'ctx>>: ArrayLikeValue<'ctx> {
}
/// An array-like value that can have its array elements accessed as a [`BasicValueEnum`].
pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndexer<'ctx, Index> {
pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>:
ArrayLikeIndexer<'ctx, Index>
{
/// # Safety
///
/// This function should be called with a valid index.
@ -181,7 +182,9 @@ pub trait UntypedArrayLikeAccessor<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndex
}
/// An array-like value that can have its array elements mutated as a [`BasicValueEnum`].
pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndexer<'ctx, Index> {
pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>:
ArrayLikeIndexer<'ctx, Index>
{
/// # Safety
///
/// This function should be called with a valid index.
@ -210,9 +213,15 @@ pub trait UntypedArrayLikeMutator<'ctx, Index = IntValue<'ctx>>: ArrayLikeIndexe
}
/// An array-like value that can have its array elements accessed as an arbitrary type `T`.
pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayLikeAccessor<'ctx, Index> {
pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>:
UntypedArrayLikeAccessor<'ctx, Index>
{
/// Casts an element from [`BasicValueEnum`] into `T`.
fn downcast_to_type(&self, ctx: &mut CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) -> T;
fn downcast_to_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> T;
/// # Safety
///
@ -242,9 +251,15 @@ pub trait TypedArrayLikeAccessor<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayL
}
/// An array-like value that can have its array elements mutated as an arbitrary type `T`.
pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayLikeMutator<'ctx, Index> {
pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>:
UntypedArrayLikeMutator<'ctx, Index>
{
/// Casts an element from T into [`BasicValueEnum`].
fn upcast_from_type(&self, ctx: &mut CodeGenContext<'ctx, '_>, value: T) -> BasicValueEnum<'ctx>;
fn upcast_from_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: T,
) -> BasicValueEnum<'ctx>;
/// # Safety
///
@ -274,7 +289,8 @@ pub trait TypedArrayLikeMutator<'ctx, T, Index = IntValue<'ctx>>: UntypedArrayLi
}
/// Type alias for a function that casts a [`BasicValueEnum`] into a `T`.
type ValueDowncastFn<'ctx, T> = Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> T>;
type ValueDowncastFn<'ctx, T> =
Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> T>;
/// Type alias for a function that casts a `T` into a [`BasicValueEnum`].
type ValueUpcastFn<'ctx, T> = Box<dyn Fn(&mut CodeGenContext<'ctx, '_>, T) -> BasicValueEnum<'ctx>>;
@ -286,7 +302,9 @@ pub struct TypedArrayLikeAdapter<'ctx, T, Adapted: ArrayLikeValue<'ctx> = ArrayS
}
impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted>
where Adapted: ArrayLikeValue<'ctx> {
where
Adapted: ArrayLikeValue<'ctx>,
{
/// Creates a [`TypedArrayLikeAdapter`].
///
/// * `adapted` - The value to be adapted.
@ -302,7 +320,9 @@ impl<'ctx, T, Adapted> TypedArrayLikeAdapter<'ctx, T, Adapted>
}
impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, Adapted>
where Adapted: ArrayLikeValue<'ctx> {
where
Adapted: ArrayLikeValue<'ctx>,
{
fn element_type<G: CodeGenerator + ?Sized>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
@ -328,8 +348,11 @@ impl<'ctx, T, Adapted> ArrayLikeValue<'ctx> for TypedArrayLikeAdapter<'ctx, T, A
}
}
impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted>
where Adapted: ArrayLikeIndexer<'ctx, Index> {
impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: ArrayLikeIndexer<'ctx, Index>,
{
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -351,21 +374,43 @@ impl<'ctx, T, Index, Adapted> ArrayLikeIndexer<'ctx, Index> for TypedArrayLikeAd
}
}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted>
where Adapted: UntypedArrayLikeAccessor<'ctx, Index> {}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted>
where Adapted: UntypedArrayLikeMutator<'ctx, Index> {}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeAccessor<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
{
}
impl<'ctx, T, Index, Adapted> UntypedArrayLikeMutator<'ctx, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
{
}
impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted>
where Adapted: UntypedArrayLikeAccessor<'ctx, Index> {
fn downcast_to_type(&self, ctx: &mut CodeGenContext<'ctx, '_>, value: BasicValueEnum<'ctx>) -> T {
impl<'ctx, T, Index, Adapted> TypedArrayLikeAccessor<'ctx, T, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeAccessor<'ctx, Index>,
{
fn downcast_to_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: BasicValueEnum<'ctx>,
) -> T {
(self.downcast_fn)(ctx, value)
}
}
impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index> for TypedArrayLikeAdapter<'ctx, T, Adapted>
where Adapted: UntypedArrayLikeMutator<'ctx, Index> {
fn upcast_from_type(&self, ctx: &mut CodeGenContext<'ctx, '_>, value: T) -> BasicValueEnum<'ctx> {
impl<'ctx, T, Index, Adapted> TypedArrayLikeMutator<'ctx, T, Index>
for TypedArrayLikeAdapter<'ctx, T, Adapted>
where
Adapted: UntypedArrayLikeMutator<'ctx, Index>,
{
fn upcast_from_type(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
value: T,
) -> BasicValueEnum<'ctx> {
(self.upcast_fn)(ctx, value)
}
}
@ -427,15 +472,11 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name
.map(|v| format!("{v}.addr"))
.unwrap_or_default();
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
ctx.builder.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[*idx],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
@ -458,9 +499,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ArraySliceValue<'ctx> {
ctx.current_loc,
);
unsafe {
self.ptr_offset_unchecked(ctx, generator, idx, name)
}
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
@ -476,31 +515,33 @@ pub struct ListType<'ctx> {
impl<'ctx> ListType<'ctx> {
/// Checks whether `llvm_ty` represents a `list` type, returning [Err] if it does not.
pub fn is_type(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
let llvm_list_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_list_ty) = llvm_list_ty else {
return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}"))
return Err(format!("Expected struct type for `list` type, got {llvm_list_ty}"));
};
if llvm_list_ty.count_fields() != 2 {
return Err(format!("Expected 2 fields in `list`, got {}", llvm_list_ty.count_fields()))
return Err(format!(
"Expected 2 fields in `list`, got {}",
llvm_list_ty.count_fields()
));
}
let list_size_ty = llvm_list_ty.get_field_type_at_index(0).unwrap();
let Ok(_) = PointerType::try_from(list_size_ty) else {
return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}"))
return Err(format!("Expected pointer type for `list.0`, got {list_size_ty}"));
};
let list_data_ty = llvm_list_ty.get_field_type_at_index(1).unwrap();
let Ok(list_data_ty) = IntType::try_from(list_data_ty) else {
return Err(format!("Expected int type for `list.1`, got {list_data_ty}"))
return Err(format!("Expected int type for `list.1`, got {list_data_ty}"));
};
if list_data_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!("Expected {}-bit int type for `list.1`, got {}-bit int",
llvm_usize.get_bit_width(),
list_data_ty.get_bit_width()))
return Err(format!(
"Expected {}-bit int type for `list.1`, got {}-bit int",
llvm_usize.get_bit_width(),
list_data_ty.get_bit_width()
));
}
Ok(())
@ -516,10 +557,7 @@ impl<'ctx> ListType<'ctx> {
let llvm_usize = generator.get_size_type(ctx);
let llvm_list = ctx
.struct_type(
&[
element_type.ptr_type(AddressSpace::default()).into(),
llvm_usize.into(),
],
&[element_type.ptr_type(AddressSpace::default()).into(), llvm_usize.into()],
false,
)
.ptr_type(AddressSpace::default());
@ -555,7 +593,7 @@ impl<'ctx> ListType<'ctx> {
.get_field_type_at_index(0)
.map(BasicTypeEnum::into_pointer_type)
.map(PointerType::get_element_type)
.unwrap()
.unwrap()
}
}
@ -612,16 +650,17 @@ pub struct ListValue<'ctx> {
impl<'ctx> ListValue<'ctx> {
/// Checks whether `value` is an instance of `list`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
ListType::is_type(value.get_type(), llvm_usize)
}
/// Creates an [`ListValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
pub fn from_ptr_val(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
<Self as ProxyValue<'ctx>>::Type::from_type(ptr.get_type(), llvm_usize)
@ -635,11 +674,13 @@ impl<'ctx> ListValue<'ctx> {
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
}
@ -649,11 +690,13 @@ impl<'ctx> ListValue<'ctx> {
let var_name = self.name.map(|v| format!("{v}.size.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
}
@ -704,7 +747,8 @@ impl<'ctx> ListValue<'ctx> {
.or_else(|| self.name.map(|v| format!("{v}.size")))
.unwrap_or_default();
ctx.builder.build_load(psize, var_name.as_str())
ctx.builder
.build_load(psize, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
@ -761,7 +805,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for ListDataProxy<'ctx, '_> {
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.pptr_to_data(ctx), var_name.as_str())
ctx.builder
.build_load(self.0.pptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
@ -783,15 +828,11 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name
.map(|v| format!("{v}.addr"))
.unwrap_or_default();
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
ctx.builder.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[*idx],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
@ -814,9 +855,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for ListDataProxy<'ctx, '_> {
ctx.current_loc,
);
unsafe {
self.ptr_offset_unchecked(ctx, generator, idx, name)
}
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
@ -834,19 +873,26 @@ impl<'ctx> RangeType<'ctx> {
pub fn is_type(llvm_ty: PointerType<'ctx>) -> Result<(), String> {
let llvm_range_ty = llvm_ty.get_element_type();
let AnyTypeEnum::ArrayType(llvm_range_ty) = llvm_range_ty else {
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"))
return Err(format!("Expected array type for `range` type, got {llvm_range_ty}"));
};
if llvm_range_ty.len() != 3 {
return Err(format!("Expected 3 elements for `range` type, got {}", llvm_range_ty.len()))
return Err(format!(
"Expected 3 elements for `range` type, got {}",
llvm_range_ty.len()
));
}
let llvm_range_elem_ty = llvm_range_ty.get_element_type();
let Ok(llvm_range_elem_ty) = IntType::try_from(llvm_range_elem_ty) else {
return Err(format!("Expected int type for `range` element type, got {llvm_range_elem_ty}"))
return Err(format!(
"Expected int type for `range` element type, got {llvm_range_elem_ty}"
));
};
if llvm_range_elem_ty.get_bit_width() != 32 {
return Err(format!("Expected 32-bit int type for `range` element type, got {}",
llvm_range_elem_ty.get_bit_width()))
return Err(format!(
"Expected 32-bit int type for `range` element type, got {}",
llvm_range_elem_ty.get_bit_width()
));
}
Ok(())
@ -872,11 +918,7 @@ impl<'ctx> RangeType<'ctx> {
/// Returns the type of all fields of this `range` type.
#[must_use]
pub fn value_type(&self) -> IntType<'ctx> {
self.as_base_type()
.get_element_type()
.into_array_type()
.get_element_type()
.into_int_type()
self.as_base_type().get_element_type().into_array_type().get_element_type().into_int_type()
}
}
@ -897,7 +939,11 @@ impl<'ctx> ProxyType<'ctx> for RangeType<'ctx> {
)
}
fn create_value(&self, value: <Self::Value as ProxyValue<'ctx>>::Base, name: Option<&'ctx str>) -> Self::Value {
fn create_value(
&self,
value: <Self::Value as ProxyValue<'ctx>>::Base,
name: Option<&'ctx str>,
) -> Self::Value {
debug_assert_eq!(value.get_type(), self.as_base_type());
RangeValue { value, name }
@ -944,11 +990,13 @@ impl<'ctx> RangeValue<'ctx> {
let var_name = self.name.map(|v| format!("{v}.start.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(0, false)],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(0, false)],
var_name.as_str(),
)
.unwrap()
}
}
@ -957,11 +1005,13 @@ impl<'ctx> RangeValue<'ctx> {
let var_name = self.name.map(|v| format!("{v}.end.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
var_name.as_str(),
)
.unwrap()
}
}
@ -970,20 +1020,18 @@ impl<'ctx> RangeValue<'ctx> {
let var_name = self.name.map(|v| format!("{v}.step.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, false)],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, false)],
var_name.as_str(),
)
.unwrap()
}
}
/// Stores the `start` value into this instance.
pub fn store_start(
&self,
ctx: &CodeGenContext<'ctx, '_>,
start: IntValue<'ctx>,
) {
pub fn store_start(&self, ctx: &CodeGenContext<'ctx, '_>, start: IntValue<'ctx>) {
debug_assert_eq!(start.get_type().get_bit_width(), 32);
let pstart = self.ptr_to_start(ctx);
@ -998,17 +1046,14 @@ impl<'ctx> RangeValue<'ctx> {
.or_else(|| self.name.map(|v| format!("{v}.start")))
.unwrap_or_default();
ctx.builder.build_load(pstart, var_name.as_str())
ctx.builder
.build_load(pstart, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
/// Stores the `end` value into this instance.
pub fn store_end(
&self,
ctx: &CodeGenContext<'ctx, '_>,
end: IntValue<'ctx>,
) {
pub fn store_end(&self, ctx: &CodeGenContext<'ctx, '_>, end: IntValue<'ctx>) {
debug_assert_eq!(end.get_type().get_bit_width(), 32);
let pend = self.ptr_to_end(ctx);
@ -1023,17 +1068,11 @@ impl<'ctx> RangeValue<'ctx> {
.or_else(|| self.name.map(|v| format!("{v}.end")))
.unwrap_or_default();
ctx.builder.build_load(pend, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
ctx.builder.build_load(pend, var_name.as_str()).map(BasicValueEnum::into_int_value).unwrap()
}
/// Stores the `step` value into this instance.
pub fn store_step(
&self,
ctx: &CodeGenContext<'ctx, '_>,
step: IntValue<'ctx>,
) {
pub fn store_step(&self, ctx: &CodeGenContext<'ctx, '_>, step: IntValue<'ctx>) {
debug_assert_eq!(step.get_type().get_bit_width(), 32);
let pstep = self.ptr_to_step(ctx);
@ -1048,7 +1087,8 @@ impl<'ctx> RangeValue<'ctx> {
.or_else(|| self.name.map(|v| format!("{v}.step")))
.unwrap_or_default();
ctx.builder.build_load(pstep, var_name.as_str())
ctx.builder
.build_load(pstep, var_name.as_str())
.map(BasicValueEnum::into_int_value)
.unwrap()
}
@ -1094,45 +1134,51 @@ pub struct NDArrayType<'ctx> {
impl<'ctx> NDArrayType<'ctx> {
/// Checks whether `llvm_ty` represents a `ndarray` type, returning [Err] if it does not.
pub fn is_type(
llvm_ty: PointerType<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
pub fn is_type(llvm_ty: PointerType<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
let llvm_ndarray_ty = llvm_ty.get_element_type();
let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else {
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"))
return Err(format!("Expected struct type for `NDArray` type, got {llvm_ndarray_ty}"));
};
if llvm_ndarray_ty.count_fields() != 3 {
return Err(format!("Expected 3 fields in `NDArray`, got {}", llvm_ndarray_ty.count_fields()))
return Err(format!(
"Expected 3 fields in `NDArray`, got {}",
llvm_ndarray_ty.count_fields()
));
}
let ndarray_ndims_ty = llvm_ndarray_ty.get_field_type_at_index(0).unwrap();
let Ok(ndarray_ndims_ty) = IntType::try_from(ndarray_ndims_ty) else {
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"))
return Err(format!("Expected int type for `ndarray.0`, got {ndarray_ndims_ty}"));
};
if ndarray_ndims_ty.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!("Expected {}-bit int type for `ndarray.0`, got {}-bit int",
llvm_usize.get_bit_width(),
ndarray_ndims_ty.get_bit_width()))
return Err(format!(
"Expected {}-bit int type for `ndarray.0`, got {}-bit int",
llvm_usize.get_bit_width(),
ndarray_ndims_ty.get_bit_width()
));
}
let ndarray_dims_ty = llvm_ndarray_ty.get_field_type_at_index(1).unwrap();
let Ok(ndarray_pdims) = PointerType::try_from(ndarray_dims_ty) else {
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"))
return Err(format!("Expected pointer type for `ndarray.1`, got {ndarray_dims_ty}"));
};
let ndarray_dims = ndarray_pdims.get_element_type();
let Ok(ndarray_dims) = IntType::try_from(ndarray_dims) else {
return Err(format!("Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"))
return Err(format!(
"Expected pointer-to-int type for `ndarray.1`, got pointer-to-{ndarray_dims}"
));
};
if ndarray_dims.get_bit_width() != llvm_usize.get_bit_width() {
return Err(format!("Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_dims.get_bit_width()))
return Err(format!(
"Expected pointer-to-{}-bit int type for `ndarray.1`, got pointer-to-{}-bit int",
llvm_usize.get_bit_width(),
ndarray_dims.get_bit_width()
));
}
let ndarray_data_ty = llvm_ndarray_ty.get_field_type_at_index(2).unwrap();
let Ok(_) = PointerType::try_from(ndarray_data_ty) else {
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"))
return Err(format!("Expected pointer type for `ndarray.2`, got {ndarray_data_ty}"));
};
Ok(())
@ -1151,13 +1197,13 @@ impl<'ctx> NDArrayType<'ctx> {
//
// * num_dims: Number of dimensions in the array
// * dims: Pointer to an array containing the size of each dimension
// * data: Pointer to an array containing the array data
// * data: Pointer to an array containing the array data
let llvm_ndarray = ctx
.struct_type(
&[
llvm_usize.into(),
llvm_usize.ptr_type(AddressSpace::default()).into(),
dtype.ptr_type(AddressSpace::default()).into(),
dtype.ptr_type(AddressSpace::default()).into(),
],
false,
)
@ -1193,7 +1239,7 @@ impl<'ctx> NDArrayType<'ctx> {
.into_struct_type()
.get_field_type_at_index(2)
.unwrap()
}
}
}
impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
@ -1208,9 +1254,7 @@ impl<'ctx> ProxyType<'ctx> for NDArrayType<'ctx> {
name: Option<&'ctx str>,
) -> Self::Value {
self.create_value(
generator
.gen_var_alloc(ctx, self.as_underlying_type().into(), name)
.unwrap(),
generator.gen_var_alloc(ctx, self.as_underlying_type().into(), name).unwrap(),
name,
)
}
@ -1251,16 +1295,17 @@ pub struct NDArrayValue<'ctx> {
impl<'ctx> NDArrayValue<'ctx> {
/// Checks whether `value` is an instance of `NDArray`, returning [Err] if `value` is not an
/// instance.
pub fn is_instance(
value: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
) -> Result<(), String> {
pub fn is_instance(value: PointerValue<'ctx>, llvm_usize: IntType<'ctx>) -> Result<(), String> {
NDArrayType::is_type(value.get_type(), llvm_usize)
}
/// Creates an [`NDArrayValue`] from a [`PointerValue`].
#[must_use]
pub fn from_ptr_val(ptr: PointerValue<'ctx>, llvm_usize: IntType<'ctx>, name: Option<&'ctx str>) -> Self {
pub fn from_ptr_val(
ptr: PointerValue<'ctx>,
llvm_usize: IntType<'ctx>,
name: Option<&'ctx str>,
) -> Self {
debug_assert!(Self::is_instance(ptr, llvm_usize).is_ok());
<Self as ProxyValue<'ctx>>::Type::from_type(ptr.get_type(), llvm_usize)
@ -1273,11 +1318,13 @@ impl<'ctx> NDArrayValue<'ctx> {
let var_name = self.name.map(|v| format!("{v}.ndims.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
var_name.as_str(),
)
.unwrap()
}
}
@ -1297,9 +1344,7 @@ impl<'ctx> NDArrayValue<'ctx> {
/// Returns the number of dimensions of this `NDArray` as a value.
pub fn load_ndims(&self, ctx: &CodeGenContext<'ctx, '_>) -> IntValue<'ctx> {
let pndims = self.ptr_to_ndims(ctx);
ctx.builder.build_load(pndims, "")
.map(BasicValueEnum::into_int_value)
.unwrap()
ctx.builder.build_load(pndims, "").map(BasicValueEnum::into_int_value).unwrap()
}
/// Returns the double-indirection pointer to the `dims` array, as if by calling `getelementptr`
@ -1309,11 +1354,13 @@ impl<'ctx> NDArrayValue<'ctx> {
let var_name = self.name.map(|v| format!("{v}.dims.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(1, true)],
var_name.as_str(),
)
.unwrap()
}
}
@ -1345,11 +1392,13 @@ impl<'ctx> NDArrayValue<'ctx> {
let var_name = self.name.map(|v| format!("{v}.data.addr")).unwrap_or_default();
unsafe {
ctx.builder.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(
self.as_base_value(),
&[llvm_i32.const_zero(), llvm_i32.const_int(2, true)],
var_name.as_str(),
)
.unwrap()
}
}
@ -1427,7 +1476,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDimsProxy<'ctx, '_> {
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.ptr_to_dims(ctx), var_name.as_str())
ctx.builder
.build_load(self.0.ptr_to_dims(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
@ -1449,15 +1499,11 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
let var_name = name
.map(|v| format!("{v}.addr"))
.unwrap_or_default();
let var_name = name.map(|v| format!("{v}.addr")).unwrap_or_default();
ctx.builder.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[*idx],
var_name.as_str(),
).unwrap()
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], var_name.as_str())
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
@ -1468,12 +1514,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
name: Option<&str>,
) -> PointerValue<'ctx> {
let size = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(
IntPredicate::ULT,
*idx,
size,
""
).unwrap();
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, size, "").unwrap();
ctx.make_assert(
generator,
in_range,
@ -1483,9 +1524,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx, IntValue<'ctx>> for NDArrayDimsProxy<'ctx, '_>
ctx.current_loc,
);
unsafe {
self.ptr_offset_unchecked(ctx, generator, idx, name)
}
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
@ -1532,7 +1571,8 @@ impl<'ctx> ArrayLikeValue<'ctx> for NDArrayDataProxy<'ctx, '_> {
) -> PointerValue<'ctx> {
let var_name = self.0.name.map(|v| format!("{v}.data")).unwrap_or_default();
ctx.builder.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
ctx.builder
.build_load(self.0.ptr_to_data(ctx), var_name.as_str())
.map(BasicValueEnum::into_pointer_value)
.unwrap()
}
@ -1554,11 +1594,9 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
idx: &IntValue<'ctx>,
name: Option<&str>,
) -> PointerValue<'ctx> {
ctx.builder.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[*idx],
name.unwrap_or_default(),
).unwrap()
ctx.builder
.build_in_bounds_gep(self.base_ptr(ctx, generator), &[*idx], name.unwrap_or_default())
.unwrap()
}
fn ptr_offset<G: CodeGenerator + ?Sized>(
@ -1569,12 +1607,7 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
name: Option<&str>,
) -> PointerValue<'ctx> {
let data_sz = self.size(ctx, generator);
let in_range = ctx.builder.build_int_compare(
IntPredicate::ULT,
*idx,
data_sz,
""
).unwrap();
let in_range = ctx.builder.build_int_compare(IntPredicate::ULT, *idx, data_sz, "").unwrap();
ctx.make_assert(
generator,
in_range,
@ -1584,16 +1617,16 @@ impl<'ctx> ArrayLikeIndexer<'ctx> for NDArrayDataProxy<'ctx, '_> {
ctx.current_loc,
);
unsafe {
self.ptr_offset_unchecked(ctx, generator, idx, name)
}
unsafe { self.ptr_offset_unchecked(ctx, generator, idx, name) }
}
}
impl<'ctx> UntypedArrayLikeAccessor<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx> UntypedArrayLikeMutator<'ctx, IntValue<'ctx>> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index> for NDArrayDataProxy<'ctx, '_> {
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
unsafe fn ptr_offset_unchecked<G: CodeGenerator + ?Sized>(
&self,
ctx: &mut CodeGenContext<'ctx, '_>,
@ -1610,21 +1643,23 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
let Ok(indices_elem_ty) = IntType::try_from(indices_elem_ty) else {
panic!("Expected list[int32] but got {indices_elem_ty}")
};
assert_eq!(indices_elem_ty.get_bit_width(), 32, "Expected list[int32] but got list[int{}]", indices_elem_ty.get_bit_width());
let index = call_ndarray_flatten_index(
generator,
ctx,
*self.0,
indices,
assert_eq!(
indices_elem_ty.get_bit_width(),
32,
"Expected list[int32] but got list[int{}]",
indices_elem_ty.get_bit_width()
);
let index = call_ndarray_flatten_index(generator, ctx, *self.0, indices);
unsafe {
ctx.builder.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[index],
name.unwrap_or_default(),
).unwrap()
ctx.builder
.build_in_bounds_gep(
self.base_ptr(ctx, generator),
&[index],
name.unwrap_or_default(),
)
.unwrap()
}
}
@ -1638,12 +1673,10 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
let llvm_usize = generator.get_size_type(ctx.ctx);
let indices_size = indices.size(ctx, generator);
let nidx_leq_ndims = ctx.builder.build_int_compare(
IntPredicate::SLE,
indices_size,
self.0.load_ndims(ctx),
""
).unwrap();
let nidx_leq_ndims = ctx
.builder
.build_int_compare(IntPredicate::SLE, indices_size, self.0.load_ndims(ctx), "")
.unwrap();
ctx.make_assert(
generator,
nidx_leq_ndims,
@ -1668,16 +1701,13 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
self.0.dim_sizes().get_typed_unchecked(ctx, generator, &i, None),
)
};
let dim_idx = ctx.builder
let dim_idx = ctx
.builder
.build_int_z_extend_or_bit_cast(dim_idx, dim_sz.get_type(), "")
.unwrap();
let dim_lt = ctx.builder.build_int_compare(
IntPredicate::SLT,
dim_idx,
dim_sz,
""
).unwrap();
let dim_lt =
ctx.builder.build_int_compare(IntPredicate::SLT, dim_idx, dim_sz, "").unwrap();
ctx.make_assert(
generator,
@ -1691,13 +1721,18 @@ impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> ArrayLikeIndexer<'ctx, Index>
Ok(())
},
llvm_usize.const_int(1, false),
).unwrap();
)
.unwrap();
unsafe {
self.ptr_offset_unchecked(ctx, generator, indices, name)
}
unsafe { self.ptr_offset_unchecked(ctx, generator, indices, name) }
}
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index> for NDArrayDataProxy<'ctx, '_> {}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeAccessor<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}
impl<'ctx, Index: UntypedArrayLikeAccessor<'ctx>> UntypedArrayLikeMutator<'ctx, Index>
for NDArrayDataProxy<'ctx, '_>
{
}

View File

@ -7,9 +7,9 @@ use crate::{
},
};
use indexmap::IndexMap;
use nac3parser::ast::StrRef;
use std::collections::HashMap;
use indexmap::IndexMap;
pub struct ConcreteTypeStore {
store: Vec<ConcreteTypeEnum>,
@ -202,9 +202,9 @@ impl ConcreteTypeStore {
TypeEnum::TFunc(signature) => {
self.from_signature(unifier, primitives, signature, cache)
}
TypeEnum::TLiteral { values, .. } => ConcreteTypeEnum::TLiteral {
values: values.clone(),
},
TypeEnum::TLiteral { values, .. } => {
ConcreteTypeEnum::TLiteral { values: values.clone() }
}
_ => unreachable!("{:?}", ty_enum.get_type_name()),
};
let index = if let Some(ConcreteType(index)) = cache.get(&ty).unwrap() {
@ -292,9 +292,8 @@ impl ConcreteTypeStore {
.map(|(id, cty)| (*id, self.to_unifier_type(unifier, primitives, *cty, cache)))
.collect::<VarMap>(),
}),
ConcreteTypeEnum::TLiteral { values, .. } => TypeEnum::TLiteral {
values: values.clone(),
loc: None,
ConcreteTypeEnum::TLiteral { values, .. } => {
TypeEnum::TLiteral { values: values.clone(), loc: None }
}
};
let result = unifier.add_ty(result);

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,7 @@ pub fn call_tan<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -53,7 +53,7 @@ pub fn call_asin<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -85,7 +85,7 @@ pub fn call_acos<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -117,7 +117,7 @@ pub fn call_atan<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -149,7 +149,7 @@ pub fn call_sinh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -181,7 +181,7 @@ pub fn call_cosh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -213,7 +213,7 @@ pub fn call_tanh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -245,7 +245,7 @@ pub fn call_asinh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -277,7 +277,7 @@ pub fn call_acosh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -309,7 +309,7 @@ pub fn call_atanh<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -341,7 +341,7 @@ pub fn call_expm1<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -373,7 +373,7 @@ pub fn call_cbrt<'ctx>(
for attr in ["mustprogress", "nofree", "nosync", "nounwind", "readonly", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -404,7 +404,7 @@ pub fn call_erf<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
@ -434,7 +434,7 @@ pub fn call_erfc<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
@ -465,7 +465,7 @@ pub fn call_j1<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
@ -498,7 +498,7 @@ pub fn call_atan2<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -533,7 +533,7 @@ pub fn call_ldexp<'ctx>(
for attr in ["mustprogress", "nofree", "nounwind", "willreturn"] {
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
);
}
@ -566,7 +566,7 @@ pub fn call_hypot<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
@ -598,7 +598,7 @@ pub fn call_nextafter<'ctx>(
let func = ctx.module.add_function(FN_NAME, fn_type, None);
func.add_attribute(
AttributeLoc::Function,
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0)
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id("nounwind"), 0),
);
func
@ -610,4 +610,4 @@ pub fn call_nextafter<'ctx>(
.map(|v| v.map_left(BasicValueEnum::into_float_value))
.map(Either::unwrap_left)
.unwrap()
}
}

View File

@ -1,5 +1,5 @@
use crate::{
codegen::{classes::ArraySliceValue, expr::*, stmt::*, bool_to_i1, bool_to_i8, CodeGenContext},
codegen::{bool_to_i1, bool_to_i8, classes::ArraySliceValue, expr::*, stmt::*, CodeGenContext},
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type},
@ -210,7 +210,7 @@ pub trait CodeGenerator {
fn bool_to_i1<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>
bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> {
bool_to_i1(&ctx.builder, bool_value)
}
@ -219,7 +219,7 @@ pub trait CodeGenerator {
fn bool_to_i8<'ctx>(
&self,
ctx: &CodeGenContext<'ctx, '_>,
bool_value: IntValue<'ctx>
bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> {
bool_to_i8(&ctx.builder, ctx.ctx, bool_value)
}
@ -239,7 +239,6 @@ impl DefaultCodeGenerator {
}
impl CodeGenerator for DefaultCodeGenerator {
/// Returns the name for this [`CodeGenerator`].
fn get_name(&self) -> &str {
&self.name

View File

@ -2,18 +2,13 @@ use crate::typecheck::typedef::Type;
use super::{
classes::{
ArrayLikeIndexer,
ArrayLikeValue,
ArraySliceValue,
ListValue,
NDArrayValue,
TypedArrayLikeAdapter,
UntypedArrayLikeAccessor,
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue,
TypedArrayLikeAdapter, UntypedArrayLikeAccessor,
},
CodeGenContext,
CodeGenerator,
llvm_intrinsics,
llvm_intrinsics, CodeGenContext, CodeGenerator,
};
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
use inkwell::{
attributes::{Attribute, AttributeLoc},
context::Context,
@ -25,8 +20,6 @@ use inkwell::{
};
use itertools::Either;
use nac3parser::ast::Expr;
use crate::codegen::classes::TypedArrayLikeAccessor;
use crate::codegen::stmt::gen_for_callback_incrementing;
#[must_use]
pub fn load_irrt(ctx: &Context) -> Module {
@ -70,12 +63,15 @@ pub fn integer_power<'ctx, G: CodeGenerator + ?Sized>(
ctx.module.add_function(symbol, fn_type, None)
});
// throw exception when exp < 0
let ge_zero = ctx.builder.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
).unwrap();
let ge_zero = ctx
.builder
.build_int_compare(
IntPredicate::SGE,
exp,
exp.get_type().const_zero(),
"assert_int_pow_ge_0",
)
.unwrap();
ctx.make_assert(
generator,
ge_zero,
@ -107,12 +103,10 @@ pub fn calculate_len_for_slice_range<'ctx, G: CodeGenerator + ?Sized>(
});
// assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
).unwrap();
let not_zero = ctx
.builder
.build_int_compare(IntPredicate::NE, step, step.get_type().const_zero(), "range_step_ne")
.unwrap();
ctx.make_assert(
generator,
not_zero,
@ -208,15 +202,18 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
let step = if let Some(v) = generator.gen_expr(ctx, step)? {
v.to_basic_value_enum(ctx, generator, ctx.primitives.int32)?.into_int_value()
} else {
return Ok(None)
return Ok(None);
};
// assert step != 0, throw exception if not
let not_zero = ctx.builder.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
).unwrap();
let not_zero = ctx
.builder
.build_int_compare(
IntPredicate::NE,
step,
step.get_type().const_zero(),
"range_step_ne",
)
.unwrap();
ctx.make_assert(
generator,
not_zero,
@ -226,25 +223,32 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
ctx.current_loc,
);
let len_id = ctx.builder.build_int_sub(length, one, "lenmin1").unwrap();
let neg = ctx.builder.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg").unwrap();
let neg = ctx
.builder
.build_int_compare(IntPredicate::SLT, step, zero, "step_is_neg")
.unwrap();
(
match s {
Some(s) => {
let Some(s) = handle_slice_index_bound(s, ctx, generator, length)? else {
return Ok(None)
return Ok(None);
};
ctx.builder
.build_select(
ctx.builder.build_and(
ctx.builder.build_int_compare(
IntPredicate::EQ,
s,
length,
"s_eq_len",
).unwrap(),
neg,
"should_minus_one",
).unwrap(),
ctx.builder
.build_and(
ctx.builder
.build_int_compare(
IntPredicate::EQ,
s,
length,
"s_eq_len",
)
.unwrap(),
neg,
"should_minus_one",
)
.unwrap(),
ctx.builder.build_int_sub(s, one, "s_min").unwrap(),
s,
"final_start",
@ -252,14 +256,16 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
.map(BasicValueEnum::into_int_value)
.unwrap()
}
None => ctx.builder.build_select(neg, len_id, zero, "stt")
None => ctx
.builder
.build_select(neg, len_id, zero, "stt")
.map(BasicValueEnum::into_int_value)
.unwrap(),
},
match e {
Some(e) => {
let Some(e) = handle_slice_index_bound(e, ctx, generator, length)? else {
return Ok(None)
return Ok(None);
};
ctx.builder
.build_select(
@ -271,7 +277,9 @@ pub fn handle_slice_indices<'ctx, G: CodeGenerator>(
.map(BasicValueEnum::into_int_value)
.unwrap()
}
None => ctx.builder.build_select(neg, zero, len_id, "end")
None => ctx
.builder
.build_select(neg, zero, len_id, "end")
.map(BasicValueEnum::into_int_value)
.unwrap(),
},
@ -299,15 +307,16 @@ pub fn handle_slice_index_bound<'ctx, G: CodeGenerator>(
let i = if let Some(v) = generator.gen_expr(ctx, i)? {
v.to_basic_value_enum(ctx, generator, i.custom.unwrap())?
} else {
return Ok(None)
return Ok(None);
};
Ok(Some(ctx
.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap()))
Ok(Some(
ctx.builder
.build_call(func, &[i.into(), length.into()], "bounded_ind")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
.map(Either::unwrap_left)
.unwrap(),
))
}
/// This function handles 'end' **inclusively**.
@ -349,47 +358,33 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
let zero = int32.const_zero();
let one = int32.const_int(1, false);
let dest_arr_ptr = dest_arr.data().base_ptr(ctx, generator);
let dest_arr_ptr = ctx.builder.build_pointer_cast(
dest_arr_ptr,
elem_ptr_type,
"dest_arr_ptr_cast",
).unwrap();
let dest_arr_ptr =
ctx.builder.build_pointer_cast(dest_arr_ptr, elem_ptr_type, "dest_arr_ptr_cast").unwrap();
let dest_len = dest_arr.load_size(ctx, Some("dest.len"));
let dest_len = ctx.builder.build_int_truncate_or_bit_cast(dest_len, int32, "srclen32").unwrap();
let src_arr_ptr = src_arr.data().base_ptr(ctx, generator);
let src_arr_ptr = ctx.builder.build_pointer_cast(
src_arr_ptr,
elem_ptr_type,
"src_arr_ptr_cast",
).unwrap();
let src_arr_ptr =
ctx.builder.build_pointer_cast(src_arr_ptr, elem_ptr_type, "src_arr_ptr_cast").unwrap();
let src_len = src_arr.load_size(ctx, Some("src.len"));
let src_len = ctx.builder.build_int_truncate_or_bit_cast(src_len, int32, "srclen32").unwrap();
// index in bound and positive should be done
// assert if dest.step == 1 then len(src) <= len(dest) else len(src) == len(dest), and
// throw exception if not satisfied
let src_end = ctx.builder
let src_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(
IntPredicate::SLT,
src_idx.2,
zero,
"is_neg",
).unwrap(),
ctx.builder.build_int_compare(IntPredicate::SLT, src_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(src_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(src_idx.1, one, "e_add_one").unwrap(),
"final_e",
)
.map(BasicValueEnum::into_int_value)
.unwrap();
let dest_end = ctx.builder
let dest_end = ctx
.builder
.build_select(
ctx.builder.build_int_compare(
IntPredicate::SLT,
dest_idx.2,
zero,
"is_neg",
).unwrap(),
ctx.builder.build_int_compare(IntPredicate::SLT, dest_idx.2, zero, "is_neg").unwrap(),
ctx.builder.build_int_sub(dest_idx.1, one, "e_min_one").unwrap(),
ctx.builder.build_int_add(dest_idx.1, one, "e_add_one").unwrap(),
"final_e",
@ -400,24 +395,23 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
calculate_len_for_slice_range(generator, ctx, src_idx.0, src_end, src_idx.2);
let dest_slice_len =
calculate_len_for_slice_range(generator, ctx, dest_idx.0, dest_end, dest_idx.2);
let src_eq_dest = ctx.builder.build_int_compare(
IntPredicate::EQ,
src_slice_len,
dest_slice_len,
"slice_src_eq_dest",
).unwrap();
let src_slt_dest = ctx.builder.build_int_compare(
IntPredicate::SLT,
src_slice_len,
dest_slice_len,
"slice_src_slt_dest",
).unwrap();
let dest_step_eq_one = ctx.builder.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
).unwrap();
let src_eq_dest = ctx
.builder
.build_int_compare(IntPredicate::EQ, src_slice_len, dest_slice_len, "slice_src_eq_dest")
.unwrap();
let src_slt_dest = ctx
.builder
.build_int_compare(IntPredicate::SLT, src_slice_len, dest_slice_len, "slice_src_slt_dest")
.unwrap();
let dest_step_eq_one = ctx
.builder
.build_int_compare(
IntPredicate::EQ,
dest_idx.2,
dest_idx.2.get_type().const_int(1, false),
"slice_dest_step_eq_one",
)
.unwrap();
let cond_1 = ctx.builder.build_and(dest_step_eq_one, src_slt_dest, "slice_cond_1").unwrap();
let cond = ctx.builder.build_or(src_eq_dest, cond_1, "slice_cond").unwrap();
ctx.make_assert(
@ -461,17 +455,14 @@ pub fn list_slice_assignment<'ctx, G: CodeGenerator + ?Sized>(
.unwrap()
};
// update length
let need_update = ctx.builder
.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update")
.unwrap();
let need_update =
ctx.builder.build_int_compare(IntPredicate::NE, new_len, dest_len, "need_update").unwrap();
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let update_bb = ctx.ctx.append_basic_block(current, "update");
let cont_bb = ctx.ctx.append_basic_block(current, "cont");
ctx.builder.build_conditional_branch(need_update, update_bb, cont_bb).unwrap();
ctx.builder.position_at_end(update_bb);
let new_len = ctx.builder
.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len")
.unwrap();
let new_len = ctx.builder.build_int_z_extend_or_bit_cast(new_len, size_ty, "new_len").unwrap();
dest_arr.store_size(ctx, generator, new_len);
ctx.builder.build_unconditional_branch(cont_bb).unwrap();
ctx.builder.position_at_end(cont_bb);
@ -488,7 +479,8 @@ pub fn call_isinf<'ctx, G: CodeGenerator + ?Sized>(
ctx.module.add_function("__nac3_isinf", fn_type, None)
});
let ret = ctx.builder
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isinf")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
@ -509,7 +501,8 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
ctx.module.add_function("__nac3_isnan", fn_type, None)
});
let ret = ctx.builder
let ret = ctx
.builder
.build_call(intrinsic_fn, &[v.into()], "isnan")
.map(CallSiteValue::try_as_basic_value)
.map(|v| v.map_left(BasicValueEnum::into_int_value))
@ -520,10 +513,7 @@ pub fn call_isnan<'ctx, G: CodeGenerator + ?Sized>(
}
/// Generates a call to `gamma` in IR. Returns an `f64` representing the result.
pub fn call_gamma<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
pub fn call_gamma<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gamma").unwrap_or_else(|| {
@ -540,10 +530,7 @@ pub fn call_gamma<'ctx>(
}
/// Generates a call to `gammaln` in IR. Returns an `f64` representing the result.
pub fn call_gammaln<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
pub fn call_gammaln<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_gammaln").unwrap_or_else(|| {
@ -560,10 +547,7 @@ pub fn call_gammaln<'ctx>(
}
/// Generates a call to `j0` in IR. Returns an `f64` representing the result.
pub fn call_j0<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
v: FloatValue<'ctx>,
) -> FloatValue<'ctx> {
pub fn call_j0<'ctx>(ctx: &CodeGenContext<'ctx, '_>, v: FloatValue<'ctx>) -> FloatValue<'ctx> {
let llvm_f64 = ctx.ctx.f64_type();
let intrinsic_fn = ctx.module.get_function("__nac3_j0").unwrap_or_else(|| {
@ -583,7 +567,7 @@ pub fn call_j0<'ctx>(
/// calculated total size.
///
/// * `dims` - An [`ArrayLikeIndexer`] containing the size of each dimension.
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// * `range` - The dimension index to begin and end (exclusively) calculating the dimensions for,
/// or [`None`] if starting from the first dimension and ending at the last dimension respectively.
pub fn call_ndarray_calc_size<'ctx, G, Dims>(
generator: &G,
@ -591,9 +575,10 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
dims: &Dims,
(begin, end): (Option<IntValue<'ctx>>, Option<IntValue<'ctx>>),
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>, {
where
G: CodeGenerator + ?Sized,
Dims: ArrayLikeIndexer<'ctx>,
{
let llvm_i64 = ctx.ctx.i64_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
@ -602,19 +587,14 @@ pub fn call_ndarray_calc_size<'ctx, G, Dims>(
let ndarray_calc_size_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_size",
64 => "__nac3_ndarray_calc_size64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_size_fn_t = llvm_usize.fn_type(
&[
llvm_pi64.into(),
llvm_usize.into(),
llvm_usize.into(),
llvm_usize.into(),
],
&[llvm_pi64.into(), llvm_usize.into(), llvm_usize.into(), llvm_usize.into()],
false,
);
let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name)
.unwrap_or_else(|| {
let ndarray_calc_size_fn =
ctx.module.get_function(ndarray_calc_size_fn_name).unwrap_or_else(|| {
ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None)
});
@ -658,30 +638,22 @@ pub fn call_ndarray_calc_nd_indices<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_nd_indices_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_nd_indices",
64 => "__nac3_ndarray_calc_nd_indices64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pi32.into(),
],
false,
);
let ndarray_calc_nd_indices_fn =
ctx.module.get_function(ndarray_calc_nd_indices_fn_name).unwrap_or_else(|| {
let fn_type = llvm_void.fn_type(
&[llvm_usize.into(), llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
ctx.module.add_function(ndarray_calc_nd_indices_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let indices = ctx.builder.build_array_alloca(
llvm_i32,
ndarray_num_dims,
"",
).unwrap();
let indices = ctx.builder.build_array_alloca(llvm_i32, ndarray_num_dims, "").unwrap();
ctx.builder
.build_call(
@ -709,9 +681,10 @@ fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
ndarray: NDArrayValue<'ctx>,
indices: &Indices,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>, {
where
G: CodeGenerator + ?Sized,
Indices: ArrayLikeIndexer<'ctx>,
{
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
@ -734,26 +707,23 @@ fn call_ndarray_flatten_index_impl<'ctx, G, Indices>(
let ndarray_flatten_index_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_flatten_index",
64 => "__nac3_ndarray_flatten_index64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_flatten_index_fn = ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pi32.into(),
llvm_usize.into(),
],
false,
);
let ndarray_flatten_index_fn =
ctx.module.get_function(ndarray_flatten_index_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_usize.into()],
false,
);
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
ctx.module.add_function(ndarray_flatten_index_fn_name, fn_type, None)
});
let ndarray_num_dims = ndarray.load_ndims(ctx);
let ndarray_dims = ndarray.dim_sizes();
let index = ctx.builder
let index = ctx
.builder
.build_call(
ndarray_flatten_index_fn,
&[
@ -784,16 +754,11 @@ pub fn call_ndarray_flatten_index<'ctx, G, Index>(
ndarray: NDArrayValue<'ctx>,
indices: &Index,
) -> IntValue<'ctx>
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>, {
call_ndarray_flatten_index_impl(
generator,
ctx,
ndarray,
indices,
)
where
G: CodeGenerator + ?Sized,
Index: ArrayLikeIndexer<'ctx>,
{
call_ndarray_flatten_index_impl(generator, ctx, ndarray, indices)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
@ -810,22 +775,23 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
@ -846,36 +812,22 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
};
let llvm_usize_const_one = llvm_usize.const_int(1, false);
let lhs_eqz = ctx.builder.build_int_compare(
IntPredicate::EQ,
lhs_dim_sz,
llvm_usize_const_one,
"",
).unwrap();
let rhs_eqz = ctx.builder.build_int_compare(
IntPredicate::EQ,
rhs_dim_sz,
llvm_usize_const_one,
"",
).unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(
lhs_eqz,
rhs_eqz,
""
).unwrap();
let lhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let rhs_eqz = ctx
.builder
.build_int_compare(IntPredicate::EQ, rhs_dim_sz, llvm_usize_const_one, "")
.unwrap();
let lhs_or_rhs_eqz = ctx.builder.build_or(lhs_eqz, rhs_eqz, "").unwrap();
let lhs_eq_rhs = ctx.builder.build_int_compare(
IntPredicate::EQ,
lhs_dim_sz,
rhs_dim_sz,
""
).unwrap();
let lhs_eq_rhs = ctx
.builder
.build_int_compare(IntPredicate::EQ, lhs_dim_sz, rhs_dim_sz, "")
.unwrap();
let is_compatible = ctx.builder.build_or(
lhs_or_rhs_eqz,
lhs_eq_rhs,
""
).unwrap();
let is_compatible = ctx.builder.build_or(lhs_or_rhs_eqz, lhs_eq_rhs, "").unwrap();
ctx.make_assert(
generator,
@ -889,7 +841,8 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
Ok(())
},
llvm_usize.const_int(1, false),
).unwrap();
)
.unwrap();
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
@ -923,7 +876,11 @@ pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>(
pub fn call_ndarray_calc_broadcast_index<
'ctx,
G: CodeGenerator + ?Sized,
BroadcastIdx: UntypedArrayLikeAccessor<'ctx>,
>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
@ -937,21 +894,17 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
bw => unreachable!("Unsupported size type bit width: {}", bw),
};
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pi32.into(),
llvm_pi32.into(),
],
false,
);
let ndarray_calc_broadcast_fn =
ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[llvm_pusize.into(), llvm_usize.into(), llvm_pi32.into(), llvm_pi32.into()],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
@ -959,23 +912,13 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_zero(),
None
)
broadcast_idx.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
array_dims.into(),
array_ndims.into(),
broadcast_idx_ptr.into(),
out_idx.into(),
],
&[array_dims.into(), array_ndims.into(), broadcast_idx_ptr.into(), out_idx.into()],
"",
)
.unwrap();
@ -985,4 +928,4 @@ pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, Broadc
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
}

View File

@ -1,35 +1,35 @@
use inkwell::AddressSpace;
use crate::codegen::CodeGenContext;
use inkwell::context::Context;
use inkwell::intrinsics::Intrinsic;
use inkwell::types::AnyTypeEnum::IntType;
use inkwell::types::FloatType;
use inkwell::values::{BasicValueEnum, CallSiteValue, FloatValue, IntValue, PointerValue};
use inkwell::AddressSpace;
use itertools::Either;
use crate::codegen::CodeGenContext;
/// Returns the string representation for the floating-point type `ft` when used in intrinsic
/// functions.
fn get_float_intrinsic_repr(ctx: &Context, ft: FloatType) -> &'static str {
// Standard LLVM floating-point types
if ft == ctx.f16_type() {
return "f16"
return "f16";
}
if ft == ctx.f32_type() {
return "f32"
return "f32";
}
if ft == ctx.f64_type() {
return "f64"
return "f64";
}
if ft == ctx.f128_type() {
return "f128"
return "f128";
}
// Non-standard floating-point types
if ft == ctx.x86_f80_type() {
return "f80"
return "f80";
}
if ft == ctx.ppc_f128_type() {
return "ppcf128"
return "ppcf128";
}
unreachable!()
@ -69,9 +69,7 @@ pub fn call_stackrestore<'ctx>(ctx: &CodeGenContext<'ctx, '_>, ptr: PointerValue
.and_then(|intrinsic| intrinsic.get_declaration(&ctx.module, &[llvm_p0i8.into()]))
.unwrap();
ctx.builder
.build_call(intrinsic_fn, &[ptr.into()], "")
.unwrap();
ctx.builder.build_call(intrinsic_fn, &[ptr.into()], "").unwrap();
}
/// Invokes the [`llvm.abs`](https://llvm.org/docs/LangRef.html#llvm-abs-intrinsic) intrinsic.
@ -232,10 +230,12 @@ pub fn call_memcpy<'ctx>(
let llvm_len_t = len.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(
&ctx.module,
&[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()],
))
.and_then(|intrinsic| {
intrinsic.get_declaration(
&ctx.module,
&[llvm_dest_t.into(), llvm_src_t.into(), llvm_len_t.into()],
)
})
.unwrap();
ctx.builder
@ -315,10 +315,9 @@ pub fn call_float_powi<'ctx>(
let llvm_power_t = power.get_type();
let intrinsic_fn = Intrinsic::find(FN_NAME)
.and_then(|intrinsic| intrinsic.get_declaration(
&ctx.module,
&[llvm_val_t.into(), llvm_power_t.into()],
))
.and_then(|intrinsic| {
intrinsic.get_declaration(&ctx.module, &[llvm_val_t.into(), llvm_power_t.into()])
})
.unwrap();
ctx.builder
@ -442,7 +441,6 @@ pub fn call_float_exp2<'ctx>(
.unwrap()
}
/// Invokes the [`llvm.log`](https://llvm.org/docs/LangRef.html#llvm-log-intrinsic) intrinsic.
pub fn call_float_log<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,
@ -672,7 +670,7 @@ pub fn call_float_round<'ctx>(
.unwrap()
}
/// Invokes the
/// Invokes the
/// [`llvm.roundeven`](https://llvm.org/docs/LangRef.html#llvm-roundeven-intrinsic) intrinsic.
pub fn call_float_roundeven<'ctx>(
ctx: &CodeGenContext<'ctx, '_>,

View File

@ -1,12 +1,7 @@
use crate::{
codegen::classes::{ListType, NDArrayType, ProxyType, RangeType},
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{
helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_var_tys,
TopLevelContext,
TopLevelDef,
},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef},
typecheck::{
type_inferencer::{CodeLocation, PrimitiveStore},
typedef::{CallId, FuncArg, Type, TypeEnum, Unifier},
@ -14,24 +9,22 @@ use crate::{
};
use crossbeam::channel::{unbounded, Receiver, Sender};
use inkwell::{
AddressSpace,
IntPredicate,
OptimizationLevel,
attributes::{Attribute, AttributeLoc},
basic_block::BasicBlock,
builder::Builder,
context::Context,
debug_info::{
AsDIScope, DICompileUnit, DIFlagsConstants, DIScope, DISubprogram, DebugInfoBuilder,
},
module::Module,
passes::PassBuilderOptions,
targets::{CodeModel, RelocMode, Target, TargetMachine, TargetTriple},
types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, FunctionValue, IntValue, PhiValue, PointerValue},
debug_info::{
DebugInfoBuilder, DICompileUnit, DISubprogram, AsDIScope, DIFlagsConstants, DIScope
},
AddressSpace, IntPredicate, OptimizationLevel,
};
use itertools::Itertools;
use nac3parser::ast::{Stmt, StrRef, Location};
use nac3parser::ast::{Location, Stmt, StrRef};
use parking_lot::{Condvar, Mutex};
use std::collections::{HashMap, HashSet};
use std::sync::{
@ -91,7 +84,6 @@ pub struct CodeGenTargetMachineOptions {
}
impl CodeGenTargetMachineOptions {
/// Creates an instance of [`CodeGenTargetMachineOptions`] using the triple of the host machine.
/// Other options are set to defaults.
#[must_use]
@ -120,13 +112,11 @@ impl CodeGenTargetMachineOptions {
///
/// See [`Target::create_target_machine`].
#[must_use]
pub fn create_target_machine(
&self,
level: OptimizationLevel,
) -> Option<TargetMachine> {
pub fn create_target_machine(&self, level: OptimizationLevel) -> Option<TargetMachine> {
let triple = TargetTriple::create(self.triple.as_str());
let target = Target::from_triple(&triple)
.unwrap_or_else(|_| panic!("could not create target from target triple {}", self.triple));
let target = Target::from_triple(&triple).unwrap_or_else(|_| {
panic!("could not create target from target triple {}", self.triple)
});
target.create_target_machine(
&triple,
@ -134,7 +124,7 @@ impl CodeGenTargetMachineOptions {
self.features.as_str(),
level,
self.reloc_mode,
self.code_model
self.code_model,
)
}
}
@ -205,7 +195,6 @@ pub struct CodeGenContext<'ctx, 'a> {
}
impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
/// Whether the [current basic block][Builder::get_insert_block] referenced by `builder`
/// contains a [terminator statement][BasicBlock::get_terminator].
pub fn is_terminated(&self) -> bool {
@ -251,7 +240,6 @@ pub struct WorkerRegistry {
}
impl WorkerRegistry {
/// Creates workers for this registry.
#[must_use]
pub fn create_workers<G: CodeGenerator + Send + 'static>(
@ -373,7 +361,11 @@ impl WorkerRegistry {
*self.task_count.lock() -= 1;
self.wait_condvar.notify_all();
}
assert!(errors.is_empty(), "Codegen error: {}", errors.into_iter().sorted().join("\n----------\n"));
assert!(
errors.is_empty(),
"Codegen error: {}",
errors.into_iter().sorted().join("\n----------\n")
);
let result = module.verify();
if let Err(err) = result {
@ -386,13 +378,20 @@ impl WorkerRegistry {
.llvm_options
.target
.create_target_machine(self.llvm_options.opt_level)
.unwrap_or_else(|| panic!("could not create target machine from properties {:?}", self.llvm_options.target));
.unwrap_or_else(|| {
panic!(
"could not create target machine from properties {:?}",
self.llvm_options.target
)
});
let passes = format!("default<O{}>", self.llvm_options.opt_level as u32);
let result = module.run_passes(passes.as_str(), &target_machine, pass_options);
if let Err(err) = result {
panic!("Failed to run optimization for module `{}`: {}",
module.get_name().to_str().unwrap(),
err.to_string());
panic!(
"Failed to run optimization for module `{}`: {}",
module.get_name().to_str().unwrap(),
err.to_string()
);
}
f.run(&module);
@ -436,9 +435,9 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let result = match &*ty_enum {
TObj { obj_id, fields, .. } => {
// check to avoid treating non-class primitives as classes
if obj_id.0 <= PRIMITIVE_DEF_IDS.max_id().0 {
if PrimDef::contains_id(*obj_id) {
return match &*unifier.get_ty_immutable(ty) {
TObj { obj_id, params, .. } if *obj_id == PRIMITIVE_DEF_IDS.option => {
TObj { obj_id, params, .. } if *obj_id == PrimDef::Option.id() => {
get_llvm_type(
ctx,
module,
@ -452,23 +451,20 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
.into()
}
TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let element_type = get_llvm_type(
ctx,
module,
generator,
unifier,
top_level,
type_cache,
dtype,
ctx, module, generator, unifier, top_level, type_cache, dtype,
);
NDArrayType::new(generator, ctx, element_type).as_base_type().into()
}
_ => unreachable!("LLVM type for primitive {} is missing", unifier.stringify(ty)),
}
_ => unreachable!(
"LLVM type for primitive {} is missing",
unifier.stringify(ty)
),
};
}
// a struct with fields in the order of declaration
let top_level_defs = top_level.definitions.read();
@ -484,7 +480,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
let struct_type = ctx.opaque_struct_type(&name);
type_cache.insert(
unifier.get_representative(ty),
struct_type.ptr_type(AddressSpace::default()).into()
struct_type.ptr_type(AddressSpace::default()).into(),
);
let fields = fields_list
.iter()
@ -503,24 +499,21 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into()
};
return ty
return ty;
}
TTuple { ty } => {
// a struct with fields in the order present in the tuple
let fields = ty
.iter()
.map(|ty| {
get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, *ty,
)
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty)
})
.collect_vec();
ctx.struct_type(&fields, false).into()
}
TList { ty } => {
let element_type = get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, *ty,
);
let element_type =
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, *ty);
ListType::new(generator, ctx, element_type).as_base_type().into()
}
@ -558,7 +551,7 @@ fn get_llvm_abi_type<'ctx, G: CodeGenerator + ?Sized>(
ctx.bool_type().into()
} else {
get_llvm_type(ctx, module, generator, unifier, top_level, type_cache, ty)
}
};
}
/// Whether `sret` is needed for a return value with type `ty`.
@ -574,8 +567,9 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
match ty {
BasicTypeEnum::IntType(_) | BasicTypeEnum::PointerType(_) => false,
BasicTypeEnum::FloatType(_) if maybe_large => false,
BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 =>
ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false)),
BasicTypeEnum::StructType(ty) if maybe_large && ty.count_fields() <= 2 => {
ty.get_field_types().iter().any(|ty| need_sret_impl(*ty, false))
}
_ => true,
}
}
@ -583,14 +577,18 @@ fn need_sret(ty: BasicTypeEnum) -> bool {
}
/// Implementation for generating LLVM IR for a function.
pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>> (
pub fn gen_func_impl<
'ctx,
G: CodeGenerator,
F: FnOnce(&mut G, &mut CodeGenContext) -> Result<(), String>,
>(
context: &'ctx Context,
generator: &mut G,
registry: &WorkerRegistry,
builder: Builder<'ctx>,
module: Module<'ctx>,
task: CodeGenTask,
codegen_function: F
codegen_function: F,
) -> Result<(Builder<'ctx>, Module<'ctx>, FunctionValue<'ctx>), (Builder<'ctx>, String)> {
let top_level_ctx = registry.top_level_ctx.clone();
let static_value_store = registry.static_value_store.clone();
@ -654,7 +652,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
str_type.set_body(&fields, false);
str_type.into()
}
Some(t) => t.as_basic_type_enum()
Some(t) => t.as_basic_type_enum(),
}
}),
(primitives.range, RangeType::new(context).as_base_type().into()),
@ -671,7 +669,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::default()).as_basic_type_enum()
}
})
}),
]
.iter()
.copied()
@ -679,8 +677,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
// NOTE: special handling of option cannot use this type cache since it contains type var,
// handled inside get_llvm_type instead
let ConcreteTypeEnum::TFunc { args, ret, .. } =
task.store.get(task.signature) else {
let ConcreteTypeEnum::TFunc { args, ret, .. } = task.store.get(task.signature) else {
unreachable!()
};
@ -697,7 +694,16 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let ret_type = if unifier.unioned(ret, primitives.none) {
None
} else {
Some(get_llvm_abi_type(context, &module, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, &primitives, ret))
Some(get_llvm_abi_type(
context,
&module,
generator,
&mut unifier,
top_level_ctx.as_ref(),
&mut type_cache,
&primitives,
ret,
))
};
let has_sret = ret_type.map_or(false, |ty| need_sret(ty));
@ -724,7 +730,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let fn_type = match ret_type {
Some(ret_type) if !has_sret => ret_type.fn_type(&params, false),
_ => context.void_type().fn_type(&params, false)
_ => context.void_type().fn_type(&params, false),
};
let symbol = &task.symbol_name;
@ -739,9 +745,13 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
fn_val.set_personality_function(personality);
}
if has_sret {
fn_val.add_attribute(AttributeLoc::Param(0),
context.create_type_attribute(Attribute::get_named_enum_kind_id("sret"),
ret_type.unwrap().as_any_type_enum()));
fn_val.add_attribute(
AttributeLoc::Param(0),
context.create_type_attribute(
Attribute::get_named_enum_kind_id("sret"),
ret_type.unwrap().as_any_type_enum(),
),
);
}
let init_bb = context.append_basic_block(fn_val, "init");
@ -761,9 +771,8 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
&mut type_cache,
arg.ty,
);
let alloca = builder
.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string()))
.unwrap();
let alloca =
builder.build_alloca(local_type, &format!("{}.addr", &arg.name.to_string())).unwrap();
// Remap boolean parameters into i8
let param = if local_type.is_int_type() && param.is_int_value() {
@ -774,7 +783,8 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
bool_to_i8(&builder, context, param_val)
} else {
param_val
}.into()
}
.into()
} else {
param
};
@ -808,10 +818,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
&task
.body
.first()
.map_or_else(
|| "<nac3_internal>".to_string(),
|f| f.location.file.0.to_string(),
),
.map_or_else(|| "<nac3_internal>".to_string(), |f| f.location.file.0.to_string()),
/* directory */ "",
/* producer */ "NAC3",
/* is_optimized */ registry.llvm_options.opt_level != OptimizationLevel::None,
@ -884,10 +891,10 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
row as u32,
col as u32,
func_scope.as_debug_info_scope(),
None
None,
);
code_gen_context.builder.set_current_debug_location(loc);
let result = codegen_function(generator, &mut code_gen_context);
// after static analysis, only void functions can have no return at the end.
@ -949,7 +956,7 @@ fn bool_to_i1<'ctx>(builder: &Builder<'ctx>, bool_value: IntValue<'ctx>) -> IntV
fn bool_to_i8<'ctx>(
builder: &Builder<'ctx>,
ctx: &'ctx Context,
bool_value: IntValue<'ctx>
bool_value: IntValue<'ctx>,
) -> IntValue<'ctx> {
let value_bits = bool_value.get_type().get_bit_width();
match value_bits {
@ -965,7 +972,7 @@ fn bool_to_i8<'ctx>(
bool_value.get_type().const_zero(),
"",
)
.unwrap()
.unwrap(),
),
}
}
@ -991,11 +998,18 @@ fn gen_in_range_check<'ctx>(
stop: IntValue<'ctx>,
step: IntValue<'ctx>,
) -> IntValue<'ctx> {
let sign = ctx.builder.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "").unwrap();
let lo = ctx.builder.build_select(sign, value, stop, "")
let sign = ctx
.builder
.build_int_compare(IntPredicate::SGT, step, ctx.ctx.i32_type().const_zero(), "")
.unwrap();
let lo = ctx
.builder
.build_select(sign, value, stop, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let hi = ctx.builder.build_select(sign, stop, value, "")
let hi = ctx
.builder
.build_select(sign, stop, value, "")
.map(BasicValueEnum::into_int_value)
.unwrap();

File diff suppressed because it is too large Load Diff

View File

@ -10,12 +10,7 @@ use crate::{
expr::gen_binop_expr,
gen_in_range_check,
},
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_var_tys,
TopLevelDef,
},
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, Type, TypeEnum},
};
use inkwell::{
@ -116,13 +111,13 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
ctx.var_assignment.insert(*id, (ptr, None, counter));
ptr
}
}
},
ExprKind::Attribute { value, attr, .. } => {
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
let val = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
} else {
return Ok(None)
return Ok(None);
};
let BasicValueEnum::PointerValue(ptr) = val else {
unreachable!();
@ -136,7 +131,8 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
],
name.unwrap_or(""),
)
}.unwrap()
}
.unwrap()
}
ExprKind::Subscript { value, slice, .. } => {
match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() {
@ -153,11 +149,13 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
.unwrap()
.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?
.into_int_value();
let raw_index = ctx.builder
let raw_index = ctx
.builder
.build_int_s_extend(raw_index, generator.get_size_type(ctx.ctx), "sext")
.unwrap();
// handle negative index
let is_negative = ctx.builder
let is_negative = ctx
.builder
.build_int_compare(
IntPredicate::SLT,
raw_index,
@ -173,13 +171,9 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
.unwrap();
// unsigned less than is enough, because negative index after adjustment is
// bigger than the length (for unsigned cmp)
let bound_check = ctx.builder
.build_int_compare(
IntPredicate::ULT,
index,
len,
"inbound",
)
let bound_check = ctx
.builder
.build_int_compare(IntPredicate::ULT, index, len, "inbound")
.unwrap();
ctx.make_assert(
generator,
@ -192,7 +186,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
v.data().ptr_offset(ctx, generator, &index, name)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
todo!()
}
@ -215,7 +209,8 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
match &target.node {
ExprKind::Tuple { elts, .. } => {
let BasicValueEnum::StructValue(v) =
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? else {
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
else {
unreachable!()
};
@ -230,9 +225,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
ExprKind::Subscript { value: ls, slice, .. }
if matches!(&slice.node, ExprKind::Slice { .. }) =>
{
let ExprKind::Slice { lower, upper, step } = &slice.node else {
unreachable!()
};
let ExprKind::Slice { lower, upper, step } = &slice.node else { unreachable!() };
let ls = generator
.gen_expr(ctx, ls)?
@ -240,21 +233,18 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
.to_basic_value_enum(ctx, generator, ls.custom.unwrap())?
.into_pointer_value();
let ls = ListValue::from_ptr_val(ls, llvm_usize, None);
let Some((start, end, step)) = handle_slice_indices(
lower,
upper,
step,
ctx,
generator,
ls.load_size(ctx, None),
)? else { return Ok(()) };
let Some((start, end, step)) =
handle_slice_indices(lower, upper, step, ctx, generator, ls.load_size(ctx, None))?
else {
return Ok(());
};
let value = value
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
.into_pointer_value();
let value = ListValue::from_ptr_val(value, llvm_usize, None);
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
}
_ => unreachable!(),
@ -268,7 +258,10 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
ctx,
generator,
value.load_size(ctx, None),
)? else { return Ok(()) };
)?
else {
return Ok(());
};
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
}
_ => {
@ -278,7 +271,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
String::from("target.addr")
};
let Some(ptr) = generator.gen_store_target(ctx, target, Some(name.as_str()))? else {
return Ok(())
return Ok(());
};
if let ExprKind::Name { id, .. } = &target.node {
@ -301,9 +294,7 @@ pub fn gen_for<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else {
unreachable!()
};
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else { unreachable!() };
// var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch
@ -316,11 +307,8 @@ pub fn gen_for<G: CodeGenerator>(
let body_bb = ctx.ctx.append_basic_block(current, "for.body");
let cont_bb = ctx.ctx.append_basic_block(current, "for.end");
// if there is no orelse, we just go to cont_bb
let orelse_bb = if orelse.is_empty() {
cont_bb
} else {
ctx.ctx.append_basic_block(current, "for.orelse")
};
let orelse_bb =
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "for.orelse") };
// Whether the iterable is a range() expression
let is_iterable_range_expr = ctx.unifier.unioned(iter.custom.unwrap(), ctx.primitives.range);
@ -334,20 +322,17 @@ pub fn gen_for<G: CodeGenerator>(
let loop_bb = ctx.loop_target.replace((incr_bb, cont_bb));
let iter_val = if let Some(v) = generator.gen_expr(ctx, iter)? {
v.to_basic_value_enum(
ctx,
generator,
iter.custom.unwrap(),
)?
v.to_basic_value_enum(ctx, generator, iter.custom.unwrap())?
} else {
return Ok(())
return Ok(());
};
if is_iterable_range_expr {
let iter_val = RangeValue::from_ptr_val(iter_val.into_pointer_value(), Some("range"));
// Internal variable for loop; Cannot be assigned
let i = generator.gen_var_alloc(ctx, int32.into(), Some("for.i.addr"))?;
// Variable declared in "target" expression of the loop; Can be reassigned *or* shadowed
let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))? else {
let Some(target_i) = generator.gen_store_target(ctx, target, Some("for.target.addr"))?
else {
unreachable!()
};
let (start, stop, step) = destructure_range(ctx, iter_val);
@ -355,16 +340,15 @@ pub fn gen_for<G: CodeGenerator>(
ctx.builder.build_store(i, start).unwrap();
// Check "If step is zero, ValueError is raised."
let rangenez = ctx.builder
.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "")
.unwrap();
let rangenez =
ctx.builder.build_int_compare(IntPredicate::NE, step, int32.const_zero(), "").unwrap();
ctx.make_assert(
generator,
rangenez,
"ValueError",
"range() arg 3 must not be zero",
[None, None, None],
ctx.current_loc
ctx.current_loc,
);
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
@ -385,7 +369,8 @@ pub fn gen_for<G: CodeGenerator>(
}
ctx.builder.position_at_end(incr_bb);
let next_i = ctx.builder
let next_i = ctx
.builder
.build_int_add(
ctx.builder.build_load(i, "").map(BasicValueEnum::into_int_value).unwrap(),
step,
@ -410,13 +395,14 @@ pub fn gen_for<G: CodeGenerator>(
.build_gep_and_load(
iter_val.into_pointer_value(),
&[zero, int32.const_int(1, false)],
Some("len")
Some("len"),
)
.into_int_value();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
ctx.builder.position_at_end(cond_bb);
let index = ctx.builder
let index = ctx
.builder
.build_load(index_addr, "for.index")
.map(BasicValueEnum::into_int_value)
.unwrap();
@ -424,7 +410,8 @@ pub fn gen_for<G: CodeGenerator>(
ctx.builder.build_conditional_branch(cmp, body_bb, orelse_bb).unwrap();
ctx.builder.position_at_end(incr_bb);
let index = ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let index =
ctx.builder.build_load(index_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let inc = ctx.builder.build_int_add(index, size_t.const_int(1, true), "inc").unwrap();
ctx.builder.build_store(index_addr, inc).unwrap();
ctx.builder.build_unconditional_branch(cond_bb).unwrap();
@ -433,7 +420,8 @@ pub fn gen_for<G: CodeGenerator>(
let arr_ptr = ctx
.build_gep_and_load(iter_val.into_pointer_value(), &[zero, zero], Some("arr.addr"))
.into_pointer_value();
let index = ctx.builder
let index = ctx
.builder
.build_load(index_addr, "for.index")
.map(BasicValueEnum::into_int_value)
.unwrap();
@ -496,13 +484,13 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
body: BodyFn,
update: UpdateFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
I: Clone,
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
where
G: CodeGenerator + ?Sized,
I: Clone,
InitFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<I, String>,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
UpdateFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>,
{
let current_bb = ctx.builder.get_insert_block().unwrap();
let init_bb = ctx.ctx.insert_basic_block_after(current_bb, "for.init");
@ -528,9 +516,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
let cond = cond(generator, ctx, loop_var.clone())?;
assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width());
if !ctx.is_terminated() {
ctx.builder
.build_conditional_branch(cond, body_bb, cont_bb)
.unwrap();
ctx.builder.build_conditional_branch(cond, body_bb, cont_bb).unwrap();
}
ctx.builder.position_at_end(body_bb);
@ -551,7 +537,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
Ok(())
}
/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the
/// Generates a C-style monotonically-increasing `for` construct using lambdas, similar to the
/// following C code:
///
/// ```c
@ -560,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
/// }
/// ```
///
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// as the type of the loop variable.
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum
/// value should be treated as inclusive (as opposed to exclusive).
@ -574,9 +560,9 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
body: BodyFn,
incr_val: IntValue<'ctx>,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
where
G: CodeGenerator + ?Sized,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
{
let init_val_t = init_val.get_type();
@ -590,38 +576,23 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
Ok(i_addr)
},
|_, ctx, i_addr| {
let cmp_op = if max_val.1 {
IntPredicate::ULE
} else {
IntPredicate::ULT
};
let cmp_op = if max_val.1 { IntPredicate::ULE } else { IntPredicate::ULT };
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let max_val = ctx.builder
.build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "")
.unwrap();
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let max_val =
ctx.builder.build_int_z_extend_or_bit_cast(max_val.0, init_val_t, "").unwrap();
Ok(ctx.builder.build_int_compare(cmp_op, i, max_val, "").unwrap())
},
|generator, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
body(generator, ctx, i)
},
|_, ctx, i_addr| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let incr_val = ctx.builder
.build_int_z_extend_or_bit_cast(incr_val, init_val_t, "")
.unwrap();
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let incr_val =
ctx.builder.build_int_z_extend_or_bit_cast(incr_val, init_val_t, "").unwrap();
let i = ctx.builder.build_int_add(i, incr_val, "").unwrap();
ctx.builder.build_store(i_addr, i).unwrap();
@ -632,21 +603,21 @@ pub fn gen_for_callback_incrementing<'ctx, 'a, G, BodyFn>(
/// Generates a `for` construct over a `range`-like iterable using lambdas, similar to the following
/// C code:
///
///
/// ```c
/// bool incr = start_fn() <= end_fn();
/// for (int i = start_fn(); i /* < or > */ end_fn(); i += step_fn()) {
/// body_fn(i);
/// }
/// ```
///
///
/// - `is_unsigned`: Whether to treat the values of the `range` as unsigned.
/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like
/// - `start_fn`: A lambda of IR statements that retrieves the `start` value of the `range`-like
/// iterable.
/// - `stop_fn`: A lambda of IR statements that retrieves the `stop` value of the `range`-like
/// iterable. This value will be extended to the size of `start`.
/// - `stop_inclusive`: Whether the stop value should be treated as inclusive.
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
/// - `step_fn`: A lambda of IR statements that retrieves the `step` value of the `range`-like
/// iterable. This value will be extended to the size of `start`.
/// - `body_fn`: A lambda of IR statements within the loop body.
pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
@ -658,16 +629,14 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
step_fn: StepFn,
body_fn: BodyFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
where
G: CodeGenerator + ?Sized,
StartFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StopFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
StepFn: Fn(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
BodyFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result<(), String>,
{
let init_val_t = start_fn(generator, ctx)
.map(IntValue::get_type)
.unwrap();
let init_val_t = start_fn(generator, ctx).map(IntValue::get_type).unwrap();
gen_for_callback(
generator,
@ -688,12 +657,15 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
ctx.builder.build_int_s_extend(stop, start.get_type(), "").unwrap()
};
let incr = ctx.builder.build_int_compare(
if is_unsigned { IntPredicate::ULE } else { IntPredicate::SLE },
start,
stop,
"",
).unwrap();
let incr = ctx
.builder
.build_int_compare(
if is_unsigned { IntPredicate::ULE } else { IntPredicate::SLE },
start,
stop,
"",
)
.unwrap();
Ok((i_addr, incr))
},
@ -705,10 +677,7 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
(false, false) => (IntPredicate::SLT, IntPredicate::SGT),
};
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let stop = stop_fn(generator, ctx)?;
let stop = if stop.get_type().get_bit_width() == i.get_type().get_bit_width() {
stop
@ -718,14 +687,11 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
ctx.builder.build_int_s_extend(stop, i.get_type(), "").unwrap()
};
let i_lt_end = ctx.builder
.build_int_compare(lt_cmp_op, i, stop, "")
.unwrap();
let i_gt_end = ctx.builder
.build_int_compare(gt_cmp_op, i, stop, "")
.unwrap();
let i_lt_end = ctx.builder.build_int_compare(lt_cmp_op, i, stop, "").unwrap();
let i_gt_end = ctx.builder.build_int_compare(gt_cmp_op, i, stop, "").unwrap();
let cond = ctx.builder
let cond = ctx
.builder
.build_select(incr, i_lt_end, i_gt_end, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
@ -733,18 +699,12 @@ pub fn gen_for_range_callback<'ctx, 'a, G, StartFn, StopFn, StepFn, BodyFn>(
Ok(cond)
},
|generator, ctx, (i_addr, _)| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
body_fn(generator, ctx, i)
},
|generator, ctx, (i_addr, _)| {
let i = ctx.builder
.build_load(i_addr, "")
.map(BasicValueEnum::into_int_value)
.unwrap();
let i = ctx.builder.build_load(i_addr, "").map(BasicValueEnum::into_int_value).unwrap();
let incr_val = step_fn(generator, ctx)?;
let incr_val = if incr_val.get_type().get_bit_width() == i.get_type().get_bit_width() {
@ -769,9 +729,7 @@ pub fn gen_while<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
let StmtKind::While { test, body, orelse, .. } = &stmt.node else {
unreachable!()
};
let StmtKind::While { test, body, orelse, .. } = &stmt.node else { unreachable!() };
// var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch
@ -782,8 +740,11 @@ pub fn gen_while<G: CodeGenerator>(
let body_bb = ctx.ctx.append_basic_block(current, "while.body");
let cont_bb = ctx.ctx.append_basic_block(current, "while.cont");
// if there is no orelse, we just go to cont_bb
let orelse_bb =
if orelse.is_empty() { cont_bb } else { ctx.ctx.append_basic_block(current, "while.orelse") };
let orelse_bb = if orelse.is_empty() {
cont_bb
} else {
ctx.ctx.append_basic_block(current, "while.orelse")
};
// store loop bb information and restore it later
let loop_bb = ctx.loop_target.replace((test_bb, cont_bb));
ctx.builder.build_unconditional_branch(test_bb).unwrap();
@ -796,11 +757,9 @@ pub fn gen_while<G: CodeGenerator>(
ctx.builder.build_unreachable().unwrap();
}
return Ok(())
};
let BasicValueEnum::IntValue(test) = test else {
unreachable!()
return Ok(());
};
let BasicValueEnum::IntValue(test) = test else { unreachable!() };
ctx.builder
.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb)
@ -853,12 +812,12 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
then_fn: ThenFn,
else_fn: ElseFn,
) -> Result<Option<BasicValueEnum<'ctx>>, String>
where
G: CodeGenerator + ?Sized,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
R: BasicValue<'ctx>,
where
G: CodeGenerator + ?Sized,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<Option<R>, String>,
R: BasicValue<'ctx>,
{
let current_bb = ctx.builder.get_insert_block().unwrap();
@ -893,8 +852,8 @@ pub fn gen_if_else_expr_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn, R>(
let phi = ctx.builder.build_phi(tv_ty, "").unwrap();
phi.add_incoming(&[(&tv, then_end_bb), (&ev, else_end_bb)]);
Some(phi.as_basic_value())
},
Some(phi.as_basic_value())
}
(Some(tv), None) => Some(tv.as_basic_value_enum()),
(None, Some(ev)) => Some(ev.as_basic_value_enum()),
(None, None) => None,
@ -919,11 +878,11 @@ pub fn gen_if_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn>(
then_fn: ThenFn,
else_fn: ElseFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
where
G: CodeGenerator + ?Sized,
CondFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<IntValue<'ctx>, String>,
ThenFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
ElseFn: FnOnce(&mut G, &mut CodeGenContext<'ctx, 'a>) -> Result<(), String>,
{
gen_if_else_expr_callback(
generator,
@ -936,7 +895,7 @@ pub fn gen_if_callback<'ctx, 'a, G, CondFn, ThenFn, ElseFn>(
|generator, ctx| {
else_fn(generator, ctx)?;
Ok(None)
}
},
)?;
Ok(())
@ -948,9 +907,7 @@ pub fn gen_if<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
let StmtKind::If { test, body, orelse, .. } = &stmt.node else {
unreachable!()
};
let StmtKind::If { test, body, orelse, .. } = &stmt.node else { unreachable!() };
// var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch
@ -969,9 +926,9 @@ pub fn gen_if<G: CodeGenerator>(
};
ctx.builder.build_unconditional_branch(test_bb).unwrap();
ctx.builder.position_at_end(test_bb);
let test = generator
.gen_expr(ctx, test)
.and_then(|v| v.map(|v| v.to_basic_value_enum(ctx, generator, test.custom.unwrap())).transpose())?;
let test = generator.gen_expr(ctx, test).and_then(|v| {
v.map(|v| v.to_basic_value_enum(ctx, generator, test.custom.unwrap())).transpose()
})?;
if let Some(BasicValueEnum::IntValue(test)) = test {
ctx.builder
.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb)
@ -1077,16 +1034,16 @@ pub fn exn_constructor<'ctx>(
};
let defs = ctx.top_level.definitions.read();
let def = defs[zelf_id].read();
let TopLevelDef::Class { name: zelf_name, .. } = &*def else {
unreachable!()
};
let TopLevelDef::Class { name: zelf_name, .. } = &*def else { unreachable!() };
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name);
unsafe {
let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id").unwrap();
let id = ctx.resolver.get_string_id(&exception_name);
ctx.builder.build_store(id_ptr, int32.const_int(id as u64, false)).unwrap();
let empty_string = ctx.gen_const(generator, &Constant::Str(String::new()), ctx.primitives.str);
let ptr = ctx.builder
let empty_string =
ctx.gen_const(generator, &Constant::Str(String::new()), ctx.primitives.str);
let ptr = ctx
.builder
.build_in_bounds_gep(zelf, &[zero, int32.const_int(5, false)], "exn.msg")
.unwrap();
let msg = if args.is_empty() {
@ -1101,21 +1058,24 @@ pub fn exn_constructor<'ctx>(
} else {
args.remove(0).1.to_basic_value_enum(ctx, generator, ctx.primitives.int64)?
};
let ptr = ctx.builder
let ptr = ctx
.builder
.build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.param")
.unwrap();
ctx.builder.build_store(ptr, value).unwrap();
}
// set file, func to empty string
for i in &[1, 4] {
let ptr = ctx.builder
let ptr = ctx
.builder
.build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.str")
.unwrap();
ctx.builder.build_store(ptr, empty_string.unwrap()).unwrap();
}
// set ints to zero
for i in &[2, 3] {
let ptr = ctx.builder
let ptr = ctx
.builder
.build_in_bounds_gep(zelf, &[zero, int32.const_int(*i, false)], "exn.ints")
.unwrap();
ctx.builder.build_store(ptr, zero).unwrap();
@ -1139,23 +1099,27 @@ pub fn gen_raise<'ctx, G: CodeGenerator + ?Sized>(
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let exception = exception.into_pointer_value();
let file_ptr = ctx.builder
let file_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(1, false)], "file_ptr")
.unwrap();
let filename = ctx.gen_string(generator, loc.file.0);
ctx.builder.build_store(file_ptr, filename).unwrap();
let row_ptr = ctx.builder
let row_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(2, false)], "row_ptr")
.unwrap();
ctx.builder.build_store(row_ptr, int32.const_int(loc.row as u64, false)).unwrap();
let col_ptr = ctx.builder
let col_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(3, false)], "col_ptr")
.unwrap();
ctx.builder.build_store(col_ptr, int32.const_int(loc.column as u64, false)).unwrap();
let current_fun = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let fun_name = ctx.gen_string(generator, current_fun.get_name().to_str().unwrap());
let name_ptr = ctx.builder
let name_ptr = ctx
.builder
.build_in_bounds_gep(exception, &[zero, int32.const_int(4, false)], "name_ptr")
.unwrap();
ctx.builder.build_store(name_ptr, fun_name).unwrap();
@ -1204,7 +1168,8 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
let mut final_data = None;
let has_cleanup = !finalbody.is_empty();
if has_cleanup {
let final_state = generator.gen_var_alloc(ctx, ptr_type.into(), Some("try.final_state.addr"))?;
let final_state =
generator.gen_var_alloc(ctx, ptr_type.into(), Some("try.final_state.addr"))?;
final_data = Some((final_state, Vec::new(), Vec::new()));
if let Some((continue_target, break_target)) = ctx.loop_target {
let break_proxy = ctx.ctx.append_basic_block(current_fun, "try.break");
@ -1219,8 +1184,8 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
} else {
let return_target = ctx.ctx.append_basic_block(current_fun, "try.return_target");
ctx.builder.position_at_end(return_target);
let return_value = ctx.return_buffer
.map(|v| ctx.builder.build_load(v, "$ret").unwrap());
let return_value =
ctx.return_buffer.map(|v| ctx.builder.build_load(v, "$ret").unwrap());
ctx.builder.build_return(return_value.as_ref().map(|v| v as &dyn BasicValue)).unwrap();
ctx.builder.position_at_end(current_block);
final_proxy(ctx, return_target, return_proxy, final_data.as_mut().unwrap());
@ -1250,11 +1215,12 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
&mut ctx.unifier,
type_.custom.unwrap(),
);
let obj_id = if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) {
*obj_id
} else {
unreachable!()
};
let obj_id =
if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(type_.custom.unwrap()) {
*obj_id
} else {
unreachable!()
};
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(obj_id.0), exn_name);
let exn_id = ctx.resolver.get_string_id(&exception_name);
let exn_id_global =
@ -1303,16 +1269,15 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
// run end_catch before continue/break/return
let mut final_proxy_lambda =
|ctx: &mut CodeGenContext<'ctx, 'a>,
target: BasicBlock<'ctx>,
block: BasicBlock<'ctx>| final_proxy(ctx, target, block, final_data.as_mut().unwrap());
let mut redirect_lambda = |ctx: &mut CodeGenContext<'ctx, 'a>,
target: BasicBlock<'ctx>,
block: BasicBlock<'ctx>| {
ctx.builder.position_at_end(block);
ctx.builder.build_unconditional_branch(target).unwrap();
ctx.builder.position_at_end(body);
};
|ctx: &mut CodeGenContext<'ctx, 'a>, target: BasicBlock<'ctx>, block: BasicBlock<'ctx>| {
final_proxy(ctx, target, block, final_data.as_mut().unwrap())
};
let mut redirect_lambda =
|ctx: &mut CodeGenContext<'ctx, 'a>, target: BasicBlock<'ctx>, block: BasicBlock<'ctx>| {
ctx.builder.position_at_end(block);
ctx.builder.build_unconditional_branch(target).unwrap();
ctx.builder.position_at_end(body);
};
let redirect = if has_cleanup {
&mut final_proxy_lambda
as &mut dyn FnMut(&mut CodeGenContext<'ctx, 'a>, BasicBlock<'ctx>, BasicBlock<'ctx>)
@ -1357,12 +1322,9 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
ctx.builder.position_at_end(dispatcher);
unsafe {
let zero = ctx.ctx.i32_type().const_zero();
let exnid_ptr = ctx.builder
.build_gep(
exn.as_basic_value().into_pointer_value(),
&[zero, zero],
"exnidptr",
)
let exnid_ptr = ctx
.builder
.build_gep(exn.as_basic_value().into_pointer_value(), &[zero, zero], "exnidptr")
.unwrap();
Some(ctx.builder.build_load(exnid_ptr, "exnid").unwrap())
}
@ -1388,15 +1350,15 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
post_handlers.push(current);
ctx.builder.position_at_end(dispatcher_end);
if let Some(exn_type) = exn_type {
let dispatcher_cont =
ctx.ctx.append_basic_block(current_fun, "try.dispatcher_cont");
let dispatcher_cont = ctx.ctx.append_basic_block(current_fun, "try.dispatcher_cont");
let actual_id = exnid.unwrap().into_int_value();
let expected_id = ctx
.builder
.build_load(exn_type.into_pointer_value(), "expected_id")
.map(BasicValueEnum::into_int_value)
.unwrap();
let result = ctx.builder
let result = ctx
.builder
.build_int_compare(IntPredicate::EQ, actual_id, expected_id, "exncheck")
.unwrap();
ctx.builder.build_conditional_branch(result, handler_bb, dispatcher_cont).unwrap();
@ -1522,11 +1484,9 @@ pub fn gen_return<G: CodeGenerator>(
let func = ctx.builder.get_insert_block().and_then(BasicBlock::get_parent).unwrap();
let value = if let Some(v_expr) = value.as_ref() {
if let Some(v) = generator.gen_expr(ctx, v_expr).transpose() {
Some(
v.and_then(|v| v.to_basic_value_enum(ctx, generator, v_expr.custom.unwrap()))?
)
Some(v.and_then(|v| v.to_basic_value_enum(ctx, generator, v_expr.custom.unwrap()))?)
} else {
return Ok(())
return Ok(());
}
} else {
None
@ -1554,7 +1514,8 @@ pub fn gen_return<G: CodeGenerator>(
generator.bool_to_i1(ctx, ret_val)
} else {
ret_val
}.into()
}
.into()
} else {
ret_val
}
@ -1592,16 +1553,12 @@ pub fn gen_stmt<G: CodeGenerator>(
}
StmtKind::AnnAssign { target, value, .. } => {
if let Some(value) = value {
let Some(value) = generator.gen_expr(ctx, value)? else {
return Ok(())
};
let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
generator.gen_assign(ctx, target, value)?;
}
}
StmtKind::Assign { targets, value, .. } => {
let Some(value) = generator.gen_expr(ctx, value)? else {
return Ok(())
};
let Some(value) = generator.gen_expr(ctx, value)? else { return Ok(()) };
for target in targets {
generator.gen_assign(ctx, target, value.clone())?;
}
@ -1626,7 +1583,7 @@ pub fn gen_stmt<G: CodeGenerator>(
let exc = if let Some(v) = generator.gen_expr(ctx, exc)? {
v.to_basic_value_enum(ctx, generator, exc.custom.unwrap())?
} else {
return Ok(())
return Ok(());
};
gen_raise(generator, ctx, Some(&exc), stmt.location);
} else {
@ -1637,14 +1594,16 @@ pub fn gen_stmt<G: CodeGenerator>(
let test = if let Some(v) = generator.gen_expr(ctx, test)? {
v.to_basic_value_enum(ctx, generator, test.custom.unwrap())?
} else {
return Ok(())
return Ok(());
};
let err_msg = match msg {
Some(msg) => if let Some(v) = generator.gen_expr(ctx, msg)? {
v.to_basic_value_enum(ctx, generator, msg.custom.unwrap())?
} else {
return Ok(())
},
Some(msg) => {
if let Some(v) = generator.gen_expr(ctx, msg)? {
v.to_basic_value_enum(ctx, generator, msg.custom.unwrap())?
} else {
return Ok(());
}
}
None => ctx.gen_string(generator, ""),
};
ctx.make_assert_impl(
@ -1656,7 +1615,7 @@ pub fn gen_stmt<G: CodeGenerator>(
stmt.location,
);
}
_ => unimplemented!()
_ => unimplemented!(),
};
Ok(())
}

View File

@ -1,13 +1,14 @@
use crate::{
codegen::{
classes::{ListType, NDArrayType, ProxyType, RangeType},
concrete_type::ConcreteTypeStore, CodeGenContext, CodeGenerator, CodeGenLLVMOptions,
CodeGenTargetMachineOptions, CodeGenTask, DefaultCodeGenerator, WithCall, WorkerRegistry,
concrete_type::ConcreteTypeStore,
CodeGenContext, CodeGenLLVMOptions, CodeGenTargetMachineOptions, CodeGenTask,
CodeGenerator, DefaultCodeGenerator, WithCall, WorkerRegistry,
},
symbol_resolver::{SymbolResolver, ValueEnum},
toplevel::{
composer::{ComposerConfig, TopLevelComposer}, DefinitionId, FunInstance, TopLevelContext,
TopLevelDef,
composer::{ComposerConfig, TopLevelComposer},
DefinitionId, FunInstance, TopLevelContext, TopLevelDef,
},
typecheck::{
type_inferencer::{FunctionData, Inferencer, PrimitiveStore},
@ -17,7 +18,7 @@ use crate::{
use indoc::indoc;
use inkwell::{
targets::{InitializationConfig, Target},
OptimizationLevel
OptimizationLevel,
};
use nac3parser::{
ast::{fold::Fold, StrRef},
@ -70,9 +71,7 @@ impl SymbolResolver for Resolver {
.read()
.get(&id)
.cloned()
.ok_or_else(|| HashSet::from([
format!("cannot find symbol `{}`", id),
]))
.ok_or_else(|| HashSet::from([format!("cannot find symbol `{}`", id)]))
}
fn get_string_id(&self, _: &str) -> i32 {
@ -227,12 +226,7 @@ fn test_primitives() {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(
threads,
top_level,
&llvm_options,
&f
);
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}
@ -417,12 +411,7 @@ fn test_simple_call() {
opt_level: OptimizationLevel::Default,
target: CodeGenTargetMachineOptions::from_host_triple(),
};
let (registry, handles) = WorkerRegistry::create_workers(
threads,
top_level,
&llvm_options,
&f
);
let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f);
registry.add_task(task);
registry.wait_tasks_complete(handles);
}

View File

@ -1,18 +1,18 @@
use std::fmt::Debug;
use std::rc::Rc;
use std::sync::Arc;
use std::{collections::HashMap, collections::HashSet, fmt::Display};
use std::rc::Rc;
use crate::{
codegen::{CodeGenContext, CodeGenerator},
toplevel::{DefinitionId, TopLevelDef, type_annotation::TypeAnnotation},
toplevel::{type_annotation::TypeAnnotation, DefinitionId, TopLevelDef},
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
use inkwell::values::{BasicValueEnum, FloatValue, IntValue, PointerValue, StructValue};
use itertools::{chain, Itertools, izip};
use itertools::{chain, izip, Itertools};
use nac3parser::ast::{Constant, Expr, Location, StrRef};
use parking_lot::RwLock;
@ -39,7 +39,7 @@ impl SymbolValue {
constant: &Constant,
expected_ty: Type,
primitives: &PrimitiveStore,
unifier: &mut Unifier
unifier: &mut Unifier,
) -> Result<Self, String> {
match constant {
Constant::None => {
@ -62,24 +62,16 @@ impl SymbolValue {
} else {
Err(format!("Expected {expected_ty:?}, but got str"))
}
},
}
Constant::Int(i) => {
if unifier.unioned(expected_ty, primitives.int32) {
i32::try_from(*i)
.map(SymbolValue::I32)
.map_err(|e| e.to_string())
i32::try_from(*i).map(SymbolValue::I32).map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.int64) {
i64::try_from(*i)
.map(SymbolValue::I64)
.map_err(|e| e.to_string())
i64::try_from(*i).map(SymbolValue::I64).map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint32) {
u32::try_from(*i)
.map(SymbolValue::U32)
.map_err(|e| e.to_string())
u32::try_from(*i).map(SymbolValue::U32).map_err(|e| e.to_string())
} else if unifier.unioned(expected_ty, primitives.uint64) {
u64::try_from(*i)
.map(SymbolValue::U64)
.map_err(|e| e.to_string())
u64::try_from(*i).map(SymbolValue::U64).map_err(|e| e.to_string())
} else {
Err(format!("Expected {}, but got int", unifier.stringify(expected_ty)))
}
@ -87,7 +79,10 @@ impl SymbolValue {
Constant::Tuple(t) => {
let expected_ty = unifier.get_ty(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty.as_ref() else {
return Err(format!("Expected {:?}, but got Tuple", expected_ty.get_type_name()))
return Err(format!(
"Expected {:?}, but got Tuple",
expected_ty.get_type_name()
));
};
assert_eq!(ty.len(), t.len());
@ -105,7 +100,7 @@ impl SymbolValue {
} else {
Err(format!("Expected {expected_ty:?}, but got float"))
}
},
}
_ => Err(format!("Unsupported value type {constant:?}")),
}
}
@ -113,9 +108,7 @@ impl SymbolValue {
/// Creates a [`SymbolValue`] from a [`Constant`], with its type being inferred from the constant value.
///
/// * `constant` - The constant to create the value from.
pub fn from_constant_inferred(
constant: &Constant,
) -> Result<Self, String> {
pub fn from_constant_inferred(constant: &Constant) -> Result<Self, String> {
match constant {
Constant::None => Ok(SymbolValue::OptionNone),
Constant::Bool(b) => Ok(SymbolValue::Bool(*b)),
@ -123,13 +116,19 @@ impl SymbolValue {
Constant::Int(i) => {
let i = *i;
if i >= 0 {
i32::try_from(i).map(SymbolValue::I32)
i32::try_from(i)
.map(SymbolValue::I32)
.or_else(|_| i64::try_from(i).map(SymbolValue::I64))
.map_err(|_| format!("Literal cannot be expressed as any integral type: {i}"))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
} else {
u32::try_from(i).map(SymbolValue::U32)
u32::try_from(i)
.map(SymbolValue::U32)
.or_else(|_| u64::try_from(i).map(SymbolValue::U64))
.map_err(|_| format!("Literal cannot be expressed as any integral type: {i}"))
.map_err(|_| {
format!("Literal cannot be expressed as any integral type: {i}")
})
}
}
Constant::Tuple(t) => {
@ -155,20 +154,19 @@ impl SymbolValue {
SymbolValue::Double(_) => primitives.float,
SymbolValue::Bool(_) => primitives.bool,
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
.map(|v| v.get_type(primitives, unifier))
.collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple {
ty: vs_tys,
})
let vs_tys = vs.iter().map(|v| v.get_type(primitives, unifier)).collect::<Vec<_>>();
unifier.add_ty(TypeEnum::TTuple { ty: vs_tys })
}
SymbolValue::OptionSome(_) | SymbolValue::OptionNone => primitives.option,
}
}
/// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
pub fn get_type_annotation(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> TypeAnnotation {
match self {
SymbolValue::Bool(..)
| SymbolValue::Double(..)
@ -199,7 +197,11 @@ impl SymbolValue {
}
/// Returns the [`TypeEnum`] representing the data type of this value.
pub fn get_type_enum(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> Rc<TypeEnum> {
pub fn get_type_enum(
&self,
primitives: &PrimitiveStore,
unifier: &mut Unifier,
) -> Rc<TypeEnum> {
let ty = self.get_type(primitives, unifier);
unifier.get_ty(ty)
}
@ -332,7 +334,6 @@ impl<'ctx> From<StructValue<'ctx>> for ValueEnum<'ctx> {
}
impl<'ctx> ValueEnum<'ctx> {
/// Converts this [`ValueEnum`] to a [`BasicValueEnum`].
pub fn to_basic_value_enum<'a>(
self,
@ -374,7 +375,7 @@ pub trait SymbolResolver {
&self,
_unifier: &mut Unifier,
_top_level_defs: &[Arc<RwLock<TopLevelDef>>],
_primitives: &PrimitiveStore
_primitives: &PrimitiveStore,
) -> Result<(), String> {
Ok(())
}
@ -443,40 +444,29 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if !type_vars.is_empty() {
return Err(HashSet::from([
format!(
"Unexpected number of type parameters: expected {} but got 0",
type_vars.len()
),
]))
return Err(HashSet::from([format!(
"Unexpected number of type parameters: expected {} but got 0",
type_vars.len()
)]));
}
let fields = chain(
fields.iter().map(|(k, v, m)| (*k, (*v, *m))),
methods.iter().map(|(k, v, _)| (*k, (*v, false))),
)
.collect();
Ok(unifier.add_ty(TypeEnum::TObj {
obj_id,
fields,
params: VarMap::default(),
}))
.collect();
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: VarMap::default() }))
} else {
Err(HashSet::from([
format!("Cannot use function name as type at {loc}"),
]))
Err(HashSet::from([format!("Cannot use function name as type at {loc}")]))
}
} else {
let ty = resolver
.get_symbol_type(unifier, top_level_defs, primitives, *id)
.map_err(|e| HashSet::from([
format!("Unknown type annotation at {loc}: {e}"),
]))?;
let ty =
resolver.get_symbol_type(unifier, top_level_defs, primitives, *id).map_err(
|e| HashSet::from([format!("Unknown type annotation at {loc}: {e}")]),
)?;
if let TypeEnum::TVar { .. } = &*unifier.get_ty(ty) {
Ok(ty)
} else {
Err(HashSet::from([
format!("Unknown type annotation {id} at {loc}"),
]))
Err(HashSet::from([format!("Unknown type annotation {id} at {loc}")]))
}
}
}
@ -499,9 +489,7 @@ pub fn parse_type_annotation<T>(
.collect::<Result<Vec<_>, _>>()?;
Ok(unifier.add_ty(TypeEnum::TTuple { ty }))
} else {
Err(HashSet::from([
"Expected multiple elements for tuple".into()
]))
Err(HashSet::from(["Expected multiple elements for tuple".into()]))
}
} else if *id == literal_id {
let mut parse_literal = |elt: &Expr<T>| {
@ -509,19 +497,21 @@ pub fn parse_type_annotation<T>(
let ty_enum = &*unifier.get_ty_immutable(ty);
match ty_enum {
TypeEnum::TLiteral { values, .. } => Ok(values.clone()),
_ => Err(HashSet::from([
format!("Expected literal in type argument for Literal at {}", elt.location),
]))
_ => Err(HashSet::from([format!(
"Expected literal in type argument for Literal at {}",
elt.location
)])),
}
};
let values = if let Tuple { elts, .. } = &slice.node {
elts.iter()
.map(&mut parse_literal)
.collect::<Result<Vec<_>, _>>()?
elts.iter().map(&mut parse_literal).collect::<Result<Vec<_>, _>>()?
} else {
vec![parse_literal(slice)?]
}.into_iter().flatten().collect_vec();
}
.into_iter()
.flatten()
.collect_vec();
Ok(unifier.get_fresh_literal(values, Some(slice.location)))
} else {
@ -539,13 +529,11 @@ pub fn parse_type_annotation<T>(
let def = top_level_defs[obj_id.0].read();
if let TopLevelDef::Class { fields, methods, type_vars, .. } = &*def {
if types.len() != type_vars.len() {
return Err(HashSet::from([
format!(
"Unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
types.len()
),
]))
return Err(HashSet::from([format!(
"Unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
types.len()
)]));
}
let mut subst = VarMap::new();
for (var, ty) in izip!(type_vars.iter(), types.iter()) {
@ -569,9 +557,7 @@ pub fn parse_type_annotation<T>(
}));
Ok(unifier.add_ty(TypeEnum::TObj { obj_id, fields, params: subst }))
} else {
Err(HashSet::from([
"Cannot use function name as type".into(),
]))
Err(HashSet::from(["Cannot use function name as type".into()]))
}
}
};
@ -582,17 +568,13 @@ pub fn parse_type_annotation<T>(
if let Name { id, .. } = &value.node {
subscript_name_handle(id, slice, unifier)
} else {
Err(HashSet::from([
format!("unsupported type expression at {}", expr.location),
]))
Err(HashSet::from([format!("unsupported type expression at {}", expr.location)]))
}
}
Constant { value, .. } => SymbolValue::from_constant_inferred(value)
.map(|v| unifier.get_fresh_literal(vec![v], Some(expr.location)))
.map_err(|err| HashSet::from([err])),
_ => Err(HashSet::from([
format!("unsupported type expression at {}", expr.location),
])),
_ => Err(HashSet::from([format!("unsupported type expression at {}", expr.location)])),
}
}

File diff suppressed because it is too large Load Diff

View File

@ -82,7 +82,8 @@ impl TopLevelComposer {
let mut builtin_id = HashMap::default();
let mut builtin_ty = HashMap::default();
let builtin_name_list = definition_ast_list.iter()
let builtin_name_list = definition_ast_list
.iter()
.map(|def_ast| match *def_ast.0.read() {
TopLevelDef::Class { name, .. } => name.to_string(),
TopLevelDef::Function { simple_name, .. } => simple_name.to_string(),
@ -93,19 +94,24 @@ impl TopLevelComposer {
let name = (**name).into();
let def = definition_ast_list[id].0.read();
if let TopLevelDef::Function { name: func_name, simple_name, signature, .. } = &*def {
assert_eq!(name, *simple_name, "Simple name of builtin function should match builtin name list");
assert_eq!(
name, *simple_name,
"Simple name of builtin function should match builtin name list"
);
// Do not add member functions into the list of builtin IDs;
// Here we assume that all builtin top-level functions have the same name and simple
// name, and all member functions have something prefixed to its name
if *func_name != simple_name.to_string() {
continue
continue;
}
builtin_ty.insert(name, *signature);
builtin_id.insert(name, DefinitionId(id));
} else if let TopLevelDef::Class { name, constructor, object_id, .. } = &*def
{
assert_eq!(id, object_id.0, "Object id of class '{name}' should match its index in builtin name list");
} else if let TopLevelDef::Class { name, constructor, object_id, .. } = &*def {
assert_eq!(
id, object_id.0,
"Object id of class '{name}' should match its index in builtin name list"
);
if let Some(constructor) = constructor {
builtin_ty.insert(*name, *constructor);
}
@ -384,9 +390,9 @@ impl TopLevelComposer {
let mut class_def = class_def.write();
let (class_bases_ast, class_def_type_vars, class_resolver) = {
if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def {
let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, ..
}) = class_ast else {
let Some(ast::Located { node: ast::StmtKind::ClassDef { bases, .. }, .. }) =
class_ast
else {
unreachable!()
};
@ -415,12 +421,10 @@ impl TopLevelComposer {
} =>
{
if is_generic {
return Err(HashSet::from([
format!(
"only single Generic[...] is allowed (at {})",
b.location
),
]))
return Err(HashSet::from([format!(
"only single Generic[...] is allowed (at {})",
b.location
)]));
}
is_generic = true;
@ -459,12 +463,10 @@ impl TopLevelComposer {
})
};
if !all_unique_type_var {
return Err(HashSet::from([
format!(
"duplicate type variable occurs (at {})",
slice.location
),
]))
return Err(HashSet::from([format!(
"duplicate type variable occurs (at {})",
slice.location
)]));
}
// add to TopLevelDef
@ -487,7 +489,7 @@ impl TopLevelComposer {
}
}
if !errors.is_empty() {
return Err(errors)
return Err(errors);
}
Ok(())
}
@ -514,9 +516,9 @@ impl TopLevelComposer {
} = &mut *class_def
{
let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. },
..
}) = class_ast else {
node: ast::StmtKind::ClassDef { bases, .. }, ..
}) = class_ast
else {
unreachable!()
};
@ -543,13 +545,11 @@ impl TopLevelComposer {
}
if has_base {
return Err(HashSet::from([
format!(
"a class definition can only have at most one base class \
return Err(HashSet::from([format!(
"a class definition can only have at most one base class \
declaration and one generic declaration (at {})",
b.location
),
]))
b.location
)]));
}
has_base = true;
@ -567,12 +567,10 @@ impl TopLevelComposer {
if let TypeAnnotation::CustomClass { .. } = &base_ty {
class_ancestors.push(base_ty);
} else {
return Err(HashSet::from([
format!(
"class base declaration can only be custom class (at {})",
b.location,
),
]))
return Err(HashSet::from([format!(
"class base declaration can only be custom class (at {})",
b.location,
)]));
}
}
Ok(())
@ -589,31 +587,35 @@ impl TopLevelComposer {
}
}
if !errors.is_empty() {
return Err(errors)
return Err(errors);
}
// second, get all ancestors
let mut ancestors_store: HashMap<DefinitionId, Vec<TypeAnnotation>> = HashMap::default();
let mut get_all_ancestors = |class_def: &Arc<RwLock<TopLevelDef>>| -> Result<(), HashSet<String>> {
let class_def = class_def.read();
let (class_ancestors, class_id) = {
if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def {
(ancestors, *object_id)
} else {
return Ok(());
}
let mut get_all_ancestors =
|class_def: &Arc<RwLock<TopLevelDef>>| -> Result<(), HashSet<String>> {
let class_def = class_def.read();
let (class_ancestors, class_id) = {
if let TopLevelDef::Class { ancestors, object_id, .. } = &*class_def {
(ancestors, *object_id)
} else {
return Ok(());
}
};
ancestors_store.insert(
class_id,
// if class has direct parents, get all ancestors of its parents. Else just empty
if class_ancestors.is_empty() {
vec![]
} else {
Self::get_all_ancestors_helper(
&class_ancestors[0],
temp_def_list.as_slice(),
)?
},
);
Ok(())
};
ancestors_store.insert(
class_id,
// if class has direct parents, get all ancestors of its parents. Else just empty
if class_ancestors.is_empty() {
vec![]
} else {
Self::get_all_ancestors_helper(&class_ancestors[0], temp_def_list.as_slice())?
},
);
Ok(())
};
for (class_def, ast) in self.definition_ast_list.iter().skip(self.builtin_num) {
if ast.is_none() {
continue;
@ -623,7 +625,7 @@ impl TopLevelComposer {
}
}
if !errors.is_empty() {
return Err(errors)
return Err(errors);
}
// insert the ancestors to the def list
@ -633,8 +635,7 @@ impl TopLevelComposer {
}
let mut class_def = class_def.write();
let (class_ancestors, class_id, class_type_vars) = {
if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } =
&mut *class_def
if let TopLevelDef::Class { ancestors, object_id, type_vars, .. } = &mut *class_def
{
(ancestors, *object_id, type_vars)
} else {
@ -665,8 +666,9 @@ impl TopLevelComposer {
ast::StmtKind::FunctionDef { .. } | ast::StmtKind::AnnAssign { .. }
) {
return Err(HashSet::from([
"Classes inherited from exception should have no custom fields/methods".into()
]))
"Classes inherited from exception should have no custom fields/methods"
.into(),
]));
}
}
}
@ -674,7 +676,8 @@ impl TopLevelComposer {
// deal with ancestor of Exception object
let TopLevelDef::Class { name, ancestors, object_id, .. } =
&mut *self.definition_ast_list[7].0.write() else {
&mut *self.definition_ast_list[7].0.write()
else {
unreachable!()
};
@ -713,7 +716,7 @@ impl TopLevelComposer {
}
}
if !errors.is_empty() {
return Err(errors)
return Err(errors);
}
// handle the inherited methods and fields
@ -758,9 +761,14 @@ impl TopLevelComposer {
let mut subst_list = Some(Vec::new());
// unification of previously assigned typevar
let mut unification_helper = |ty, def| -> Result<(), HashSet<String>> {
let target_ty =
get_type_from_type_annotation_kinds(&temp_def_list, unifier, &def, &mut subst_list)?;
unifier.unify(ty, target_ty)
let target_ty = get_type_from_type_annotation_kinds(
&temp_def_list,
unifier,
&def,
&mut subst_list,
)?;
unifier
.unify(ty, target_ty)
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
Ok(())
};
@ -793,14 +801,16 @@ impl TopLevelComposer {
}
}
if !errors.is_empty() {
return Err(errors)
return Err(errors);
}
for (def, _) in def_ast_list.iter().skip(self.builtin_num) {
match &*def.read() {
TopLevelDef::Class { resolver: Some(resolver), .. }
| TopLevelDef::Function { resolver: Some(resolver), .. } => {
if let Err(e) = resolver.handle_deferred_eval(unifier, &temp_def_list, primitives) {
if let Err(e) =
resolver.handle_deferred_eval(unifier, &temp_def_list, primitives)
{
errors.insert(e);
}
}
@ -828,7 +838,8 @@ impl TopLevelComposer {
return Ok(());
};
let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = function_def else {
let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = function_def
else {
// not top level function def, skip
return Ok(());
};
@ -857,25 +868,22 @@ impl TopLevelComposer {
"top level function must have unique parameter names \
and names should not be the same as the keywords (at {})",
x.location
),
]))
}}
)]));
}
}
let arg_with_default: Vec<(
&ast::Located<ast::ArgData<()>>,
Option<&ast::Expr>,
)> = args
.args
.iter()
.rev()
.zip(
args.defaults
.iter()
.rev()
.map(|x| -> Option<&ast::Expr> { Some(x) })
.chain(std::iter::repeat(None)),
)
.collect_vec();
let arg_with_default: Vec<(&ast::Located<ast::ArgData<()>>, Option<&ast::Expr>)> =
args.args
.iter()
.rev()
.zip(
args.defaults
.iter()
.rev()
.map(|x| -> Option<&ast::Expr> { Some(x) })
.chain(std::iter::repeat(None)),
)
.collect_vec();
arg_with_default
.iter()
@ -885,12 +893,12 @@ impl TopLevelComposer {
.node
.annotation
.as_ref()
.ok_or_else(|| HashSet::from([
format!(
.ok_or_else(|| {
HashSet::from([format!(
"function parameter `{}` needs type annotation at {}",
x.node.arg, x.location
),
]))?
)])
})?
.as_ref();
let type_annotation = parse_ast_to_type_annotation_kinds(
@ -926,7 +934,7 @@ impl TopLevelComposer {
temp_def_list.as_ref(),
unifier,
&type_annotation,
&mut None
&mut None,
)?;
Ok(FuncArg {
@ -935,18 +943,16 @@ impl TopLevelComposer {
default_value: match default {
None => None,
Some(default) => Some({
let v = Self::parse_parameter_default_value(
default, resolver,
)?;
let v = Self::parse_parameter_default_value(default, resolver)?;
Self::check_default_param_type(
&v,
&type_annotation,
primitives_store,
unifier,
)
.map_err(
|err| HashSet::from([format!("{} (at {})", err, x.location),
]))?;
.map_err(|err| {
HashSet::from([format!("{} (at {})", err, x.location)])
})?;
v
}),
},
@ -993,7 +999,7 @@ impl TopLevelComposer {
&temp_def_list,
unifier,
&return_ty_annotation,
&mut None
&mut None,
)?
} else {
primitives_store.none
@ -1016,9 +1022,9 @@ impl TopLevelComposer {
ret: return_ty,
vars: function_var_map,
}));
unifier.unify(*dummy_ty, function_ty).map_err(|e| HashSet::from([
e.at(Some(function_ast.location)).to_display(unifier).to_string(),
]))?;
unifier.unify(*dummy_ty, function_ty).map_err(|e| {
HashSet::from([e.at(Some(function_ast.location)).to_display(unifier).to_string()])
})?;
Ok(())
};
for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) {
@ -1030,7 +1036,7 @@ impl TopLevelComposer {
}
}
if !errors.is_empty() {
return Err(errors)
return Err(errors);
}
Ok(())
}
@ -1047,14 +1053,9 @@ impl TopLevelComposer {
let (keyword_list, core_config) = core_info;
let mut class_def = class_def.write();
let TopLevelDef::Class {
object_id,
ancestors,
fields,
methods,
resolver,
type_vars,
..
} = &mut *class_def else {
object_id, ancestors, fields, methods, resolver, type_vars, ..
} = &mut *class_def
else {
unreachable!("here must be toplevel class def");
};
let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast else {
@ -1375,14 +1376,9 @@ impl TopLevelComposer {
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
) -> Result<(), HashSet<String>> {
let TopLevelDef::Class {
object_id,
ancestors,
fields,
methods,
resolver,
type_vars,
..
} = class_def else {
object_id, ancestors, fields, methods, resolver, type_vars, ..
} = class_def
else {
unreachable!("here must be class def ast")
};
let (
@ -1414,9 +1410,7 @@ impl TopLevelComposer {
for (anc_method_name, anc_method_ty, anc_method_def_id) in methods {
// find if there is a method with same name in the child class
let mut to_be_added = (*anc_method_name, *anc_method_ty, *anc_method_def_id);
for (class_method_name, class_method_ty, class_method_defid) in
&*class_methods_def
{
for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
if class_method_name == anc_method_name {
// ignore and handle self
// if is __init__ method, no need to check return type
@ -1430,27 +1424,20 @@ impl TopLevelComposer {
if !ok {
return Err(HashSet::from([format!(
"method {class_method_name} has same name as ancestors' method, but incompatible type"),
]))
]));
}
// mark it as added
is_override.insert(*class_method_name);
to_be_added =
(*class_method_name, *class_method_ty, *class_method_defid);
to_be_added = (*class_method_name, *class_method_ty, *class_method_defid);
break;
}
}
new_child_methods.push(to_be_added);
}
// add those that are not overriding method to the new_child_methods
for (class_method_name, class_method_ty, class_method_defid) in
&*class_methods_def
{
for (class_method_name, class_method_ty, class_method_defid) in &*class_methods_def {
if !is_override.contains(class_method_name) {
new_child_methods.push((
*class_method_name,
*class_method_ty,
*class_method_defid,
));
new_child_methods.push((*class_method_name, *class_method_ty, *class_method_defid));
}
}
// use the new_child_methods to replace all the elements in `class_methods_def`
@ -1466,8 +1453,8 @@ impl TopLevelComposer {
for (class_field_name, ..) in &*class_fields_def {
if class_field_name == anc_field_name {
return Err(HashSet::from([format!(
"field `{class_field_name}` has already declared in the ancestor classes"),
]))
"field `{class_field_name}` has already declared in the ancestor classes"
)]));
}
}
new_child_fields.push(to_be_added);
@ -1499,24 +1486,30 @@ impl TopLevelComposer {
// first, fix function typevar ids
// they may be changed with our use of placeholders
for (def, _) in definition_ast_list.iter().skip(self.builtin_num) {
if let TopLevelDef::Function {
signature,
var_id,
..
} = &mut *def.write() {
if let TopLevelDef::Function { signature, var_id, .. } = &mut *def.write() {
if let TypeEnum::TFunc(FunSignature { args, ret, vars }) =
unifier.get_ty(*signature).as_ref() {
let new_var_ids = vars.values().map(|v| match &*unifier.get_ty(*v) {
TypeEnum::TVar{id, ..} => *id,
_ => unreachable!(),
}).collect_vec();
unifier.get_ty(*signature).as_ref()
{
let new_var_ids = vars
.values()
.map(|v| match &*unifier.get_ty(*v) {
TypeEnum::TVar { id, .. } => *id,
_ => unreachable!(),
})
.collect_vec();
if new_var_ids != *var_id {
let new_signature = FunSignature {
args: args.clone(),
ret: *ret,
vars: new_var_ids.iter().zip(vars.values()).map(|(id, v)| (*id, *v)).collect(),
vars: new_var_ids
.iter()
.zip(vars.values())
.map(|(id, v)| (*id, *v))
.collect(),
};
unifier.unification_table.set_value(*signature, Rc::new(TypeEnum::TFunc(new_signature)));
unifier
.unification_table
.set_value(*signature, Rc::new(TypeEnum::TFunc(new_signature)));
*var_id = new_var_ids;
}
}
@ -1542,7 +1535,7 @@ impl TopLevelComposer {
&def_list,
unifier,
&make_self_type_annotation(type_vars, *object_id),
&mut None
&mut None,
)?;
if ancestors
.iter()
@ -1590,9 +1583,12 @@ impl TopLevelComposer {
};
constructors.push((i, signature, definition_extension.len()));
definition_extension.push((Arc::new(RwLock::new(cons_fun)), None));
unifier.unify(constructor.unwrap(), signature).map_err(|e| HashSet::from([
e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string()
]))?;
unifier.unify(constructor.unwrap(), signature).map_err(|e| {
HashSet::from([e
.at(Some(ast.as_ref().unwrap().location))
.to_display(unifier)
.to_string()])
})?;
return Ok(());
}
let mut init_id: Option<DefinitionId> = None;
@ -1605,7 +1601,8 @@ impl TopLevelComposer {
init_id = Some(*id);
let func_ty_enum = unifier.get_ty(*func_sig);
let TypeEnum::TFunc(FunSignature { args, vars, .. }) =
func_ty_enum.as_ref() else {
func_ty_enum.as_ref()
else {
unreachable!("must be typeenum::tfunc")
};
@ -1620,9 +1617,12 @@ impl TopLevelComposer {
ret: self_type,
vars: contor_type_vars,
}));
unifier.unify(constructor.unwrap(), contor_type).map_err(|e| HashSet::from([
e.at(Some(ast.as_ref().unwrap().location)).to_display(unifier).to_string()
]))?;
unifier.unify(constructor.unwrap(), contor_type).map_err(|e| {
HashSet::from([e
.at(Some(ast.as_ref().unwrap().location))
.to_display(unifier)
.to_string()])
})?;
// class field instantiation check
if let (Some(init_id), false) = (init_id, fields.is_empty()) {
@ -1641,7 +1641,7 @@ impl TopLevelComposer {
class_name,
body[0].location,
),
]))
]));
}
}
}
@ -1658,11 +1658,12 @@ impl TopLevelComposer {
}
}
if !errors.is_empty() {
return Err(errors)
return Err(errors);
}
for (i, signature, id) in constructors {
let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() else {
let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write()
else {
unreachable!()
};
@ -1697,8 +1698,8 @@ impl TopLevelComposer {
} = &mut *function_def
{
let signature_ty_enum = unifier.get_ty(*signature);
let TypeEnum::TFunc(FunSignature { args, ret, vars }) =
signature_ty_enum.as_ref() else {
let TypeEnum::TFunc(FunSignature { args, ret, vars }) = signature_ty_enum.as_ref()
else {
unreachable!("must be typeenum::tfunc")
};
@ -1714,10 +1715,7 @@ impl TopLevelComposer {
let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds(
&def_list,
unifier,
&ty_ann,
&mut None
&def_list, unifier, &ty_ann, &mut None,
)?;
vars.extend(type_vars.iter().map(|ty| {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {
@ -1739,7 +1737,9 @@ impl TopLevelComposer {
.values()
.map(|ty| {
unifier.get_instantiations(*ty).unwrap_or_else(|| {
let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty) else {
let TypeEnum::TVar { name, loc, is_const_generic: false, .. } =
&*unifier.get_ty(*ty)
else {
unreachable!()
};
@ -1779,8 +1779,7 @@ impl TopLevelComposer {
let class_ty_var_ids = type_vars
.iter()
.map(|x| {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x)
{
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
*id
} else {
unreachable!("must be type var here");
@ -1839,7 +1838,8 @@ impl TopLevelComposer {
};
let ast::StmtKind::FunctionDef { body, decorator_list, .. } =
ast.clone().unwrap().node else {
ast.clone().unwrap().node
else {
unreachable!("must be function def ast")
};
if !decorator_list.is_empty()
@ -1857,13 +1857,12 @@ impl TopLevelComposer {
continue;
}
let fun_body = body
let fun_body = body
.into_iter()
.map(|b| inferencer.fold_stmt(b))
.collect::<Result<Vec<_>, _>>()?;
let returned =
inferencer.check_block(fun_body.as_slice(), &mut identifiers)?;
let returned = inferencer.check_block(fun_body.as_slice(), &mut identifiers)?;
{
// check virtuals
let defs = ctx.definitions.read();
@ -1873,9 +1872,9 @@ impl TopLevelComposer {
if let TypeEnum::TObj { obj_id, .. } = &*base {
*obj_id
} else {
return Err(HashSet::from([
format!("Base type should be a class (at {loc})"),
]))
return Err(HashSet::from([format!(
"Base type should be a class (at {loc})"
)]));
}
};
let subtype_id = {
@ -1887,7 +1886,7 @@ impl TopLevelComposer {
let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(HashSet::from([format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"),
]))
]));
}
};
let subtype_entry = defs[subtype_id.0].read();
@ -1902,7 +1901,7 @@ impl TopLevelComposer {
let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(HashSet::from([format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"),
]))
]));
}
}
}
@ -1912,7 +1911,9 @@ impl TopLevelComposer {
inst_ret,
&mut |id| {
let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read()
else { unreachable!("must be class id here") };
else {
unreachable!("must be class id here")
};
name.to_string()
},
@ -1924,11 +1925,16 @@ impl TopLevelComposer {
ret_str,
name,
ast.as_ref().unwrap().location
),]))
)]));
}
instance_to_stmt.insert(
get_subst_key(unifier, self_type, &subst, Some(&vars.keys().copied().collect())),
get_subst_key(
unifier,
self_type,
&subst,
Some(&vars.keys().copied().collect()),
),
FunInstance {
body: Arc::new(fun_body),
unifier_id: 0,
@ -1950,7 +1956,7 @@ impl TopLevelComposer {
}
}
if !errors.is_empty() {
return Err(errors)
return Err(errors);
}
Ok(())
}

View File

@ -4,75 +4,270 @@ use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::unpack_ndarray_var_tys;
use crate::typecheck::typedef::{Mapping, VarMap};
use nac3parser::ast::{Constant, Location};
use strum::IntoEnumIterator;
use strum_macros::EnumIter;
use super::*;
/// Structure storing [`DefinitionId`] for primitive types.
#[derive(Clone, Copy)]
pub struct PrimitiveDefinitionIds {
pub int32: DefinitionId,
pub int64: DefinitionId,
pub uint32: DefinitionId,
pub uint64: DefinitionId,
pub float: DefinitionId,
pub bool: DefinitionId,
pub none: DefinitionId,
pub range: DefinitionId,
pub str: DefinitionId,
pub exception: DefinitionId,
pub option: DefinitionId,
pub ndarray: DefinitionId,
/// All primitive types and functions in nac3core.
#[derive(Clone, Copy, Debug, EnumIter, PartialEq, Eq)]
pub enum PrimDef {
Int32,
Int64,
Float,
Bool,
None,
Range,
Str,
Exception,
UInt32,
UInt64,
Option,
OptionIsSome,
OptionIsNone,
OptionUnwrap,
NDArray,
NDArrayCopy,
derppening marked this conversation as resolved Outdated

This doesn't appear to be the entire list of primitive definitions. What about the other Numpy functions?

Also, I would prefer not to comment the ID itself here. Rather, I'd like a test case that asserts that each primitive class/function has the same ID as its index in the Vec returned by get_builtins.

This doesn't appear to be the entire list of primitive definitions. What about the other Numpy functions? Also, I would prefer not to comment the ID itself here. Rather, I'd like a test case that asserts that each primitive class/function has the same ID as its index in the `Vec` returned by `get_builtins`.
Outdated
Review

This doesn't appear to be the entire list of primitive definitions. What about the other Numpy functions?

I will revise on this.

Also, I would prefer not to comment the ID itself here.

I commented the IDs here since I find it very helpful when printing {:?}s and doing quick lookups to see what is what. Still should I remove them?

Rather, I'd like a test case that asserts that each primitive class/function has the same ID as its index in the Vec returned by get_builtins.

Ok, I will work on this now.

> This doesn't appear to be the entire list of primitive definitions. What about the other Numpy functions? I will revise on this. > Also, I would prefer not to comment the ID itself here. I commented the IDs here since I find it very helpful when printing `{:?}`s and doing quick lookups to see what is what. Still should I remove them? > Rather, I'd like a test case that asserts that each primitive class/function has the same ID as its index in the Vec returned by get_builtins. Ok, I will work on this now.
NDArrayFill,
FunInt32,
FunInt64,
FunUInt32,
FunUInt64,
FunFloat,
FunNpNDArray,
FunNpEmpty,
FunNpZeros,
FunNpOnes,
FunNpFull,
FunNpArray,
FunNpEye,
FunNpIdentity,
FunRound,
FunRound64,
FunNpRound,
FunRange,
FunStr,
FunBool,
FunFloor,
FunFloor64,
FunNpFloor,
FunCeil,
FunCeil64,
FunNpCeil,
FunLen,
FunMin,
FunNpMin,
FunNpMinimum,
FunMax,
FunNpMax,
FunNpMaximum,
FunAbs,
FunNpIsNan,
FunNpIsInf,
FunNpSin,
FunNpCos,
FunNpExp,
FunNpExp2,
FunNpLog,
FunNpLog10,
FunNpLog2,
FunNpFabs,
FunNpSqrt,
FunNpRint,
FunNpTan,
FunNpArcsin,
FunNpArccos,
FunNpArctan,
FunNpSinh,
FunNpCosh,
FunNpTanh,
FunNpArcsinh,
FunNpArccosh,
FunNpArctanh,
FunNpExpm1,
FunNpCbrt,
FunSpSpecErf,
FunSpSpecErfc,
FunSpSpecGamma,
FunSpSpecGammaln,
FunSpSpecJ0,
FunSpSpecJ1,
FunNpArctan2,
FunNpCopysign,
FunNpFmax,
FunNpFmin,
FunNpLdExp,
FunNpHypot,
FunNpNextAfter,
FunSome,
}
impl PrimitiveDefinitionIds {
/// Returns all [`DefinitionId`] of primitives as a [`Vec`].
/// Associated details of a [`PrimDef`]
pub enum PrimDefDetails {
PrimFunction { name: &'static str, simple_name: &'static str },
PrimClass { name: &'static str },
}
impl PrimDef {
/// Get the assigned [`DefinitionId`] of this [`PrimDef`].
///
/// There are no guarantees on ordering of the IDs.
/// The assigned definition ID is defined by the position this [`PrimDef`] enum unit variant is defined at,
/// with the first `PrimDef`'s definition id being `0`.
#[must_use]
fn as_vec(&self) -> Vec<DefinitionId> {
vec![
self.int32,
self.int64,
self.uint32,
self.uint64,
self.float,
self.bool,
self.none,
self.range,
self.str,
self.exception,
self.option,
self.ndarray,
]
pub fn id(&self) -> DefinitionId {
DefinitionId(*self as usize)
}
/// Returns an iterator over all [`DefinitionId`]s of this instance in indeterminate order.
pub fn iter(&self) -> impl Iterator<Item=DefinitionId> {
self.as_vec().into_iter()
/// Check if a definition ID is that of a [`PrimDef`].
#[must_use]
pub fn contains_id(id: DefinitionId) -> bool {
Self::iter().any(|prim| prim.id() == id)
}
/// Returns the primitive with the largest [`DefinitionId`].
/// Get the definition "simple name" of this [`PrimDef`].
///
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::simple_name`].
///
derppening marked this conversation as resolved Outdated

The lifetime 'a can be elided; Same for the functions below.

The lifetime `'a` can be elided; Same for the functions below.
/// If the [`PrimDef`] is a class, this returns [`None`].
#[must_use]
pub fn max_id(&self) -> DefinitionId {
self.iter().max().unwrap()
pub fn simple_name(&self) -> &'static str {
match self.details() {
PrimDefDetails::PrimFunction { simple_name, .. } => simple_name,
PrimDefDetails::PrimClass { .. } => {
panic!("PrimDef {self:?} has no simple_name as it is not a function.")
}
}
}
/// Get the definition "name" of this [`PrimDef`].
///
/// If the [`PrimDef`] is a function, this corresponds to [`TopLevelDef::Function::name`].
///
/// If the [`PrimDef`] is a class, this corresponds to [`TopLevelDef::Class::name`].
#[must_use]
pub fn name(&self) -> &'static str {
match self.details() {
PrimDefDetails::PrimFunction { name, .. } | PrimDefDetails::PrimClass { name } => name,
}
derppening marked this conversation as resolved Outdated

Could we just have one constructor with simple_name being an Option<&'static str>?

Could we just have one constructor with `simple_name` being an `Option<&'static str>`?
Outdated
Review

Sure, I will also change PrimDefDetails to have variants

  • PrimClass { name: &'static str } and.
  • PrimFunction { name: &'static str, simple_name: &'static str }.
Sure, I will also change `PrimDefDetails` to have variants - `PrimClass { name: &'static str }` and. - `PrimFunction { name: &'static str, simple_name: &'static str }`.
}
/// Get the associated details of this [`PrimDef`]
#[must_use]
pub fn details(self) -> PrimDefDetails {
fn class(name: &'static str) -> PrimDefDetails {
PrimDefDetails::PrimClass { name }
}
fn fun(name: &'static str, simple_name: Option<&'static str>) -> PrimDefDetails {
PrimDefDetails::PrimFunction { simple_name: simple_name.unwrap_or(name), name }
}
match self {
PrimDef::Int32 => class("int32"),
PrimDef::Int64 => class("int64"),
PrimDef::Float => class("float"),
PrimDef::Bool => class("bool"),
PrimDef::None => class("none"),
PrimDef::Range => class("range"),
PrimDef::Str => class("str"),
PrimDef::Exception => class("Exception"),
PrimDef::UInt32 => class("uint32"),
PrimDef::UInt64 => class("uint64"),
PrimDef::Option => class("Option"),
PrimDef::OptionIsSome => fun("Option.is_some", Some("is_some")),
PrimDef::OptionIsNone => fun("Option.is_none", Some("is_none")),
PrimDef::OptionUnwrap => fun("Option.unwrap", Some("unwrap")),
PrimDef::NDArray => class("ndarray"),
PrimDef::NDArrayCopy => fun("ndarray.copy", Some("copy")),
PrimDef::NDArrayFill => fun("ndarray.fill", Some("fill")),
PrimDef::FunInt32 => fun("int32", None),
PrimDef::FunInt64 => fun("int64", None),
PrimDef::FunUInt32 => fun("uint32", None),
PrimDef::FunUInt64 => fun("uint64", None),
PrimDef::FunFloat => fun("float", None),
PrimDef::FunNpNDArray => fun("np_ndarray", None),
PrimDef::FunNpEmpty => fun("np_empty", None),
PrimDef::FunNpZeros => fun("np_zeros", None),
PrimDef::FunNpOnes => fun("np_ones", None),
PrimDef::FunNpFull => fun("np_full", None),
PrimDef::FunNpArray => fun("np_array", None),
PrimDef::FunNpEye => fun("np_eye", None),
PrimDef::FunNpIdentity => fun("np_identity", None),
PrimDef::FunRound => fun("round", None),
PrimDef::FunRound64 => fun("round64", None),
PrimDef::FunNpRound => fun("np_round", None),
PrimDef::FunRange => fun("range", None),
PrimDef::FunStr => fun("str", None),
PrimDef::FunBool => fun("bool", None),
PrimDef::FunFloor => fun("floor", None),
PrimDef::FunFloor64 => fun("floor64", None),
PrimDef::FunNpFloor => fun("np_floor", None),
PrimDef::FunCeil => fun("ceil", None),
PrimDef::FunCeil64 => fun("ceil64", None),
PrimDef::FunNpCeil => fun("np_ceil", None),
PrimDef::FunLen => fun("len", None),
PrimDef::FunMin => fun("min", None),
PrimDef::FunNpMin => fun("np_min", None),
PrimDef::FunNpMinimum => fun("np_minimum", None),
PrimDef::FunMax => fun("max", None),
PrimDef::FunNpMax => fun("np_max", None),
PrimDef::FunNpMaximum => fun("np_maximum", None),
PrimDef::FunAbs => fun("abs", None),
PrimDef::FunNpIsNan => fun("np_isnan", None),
PrimDef::FunNpIsInf => fun("np_isinf", None),
PrimDef::FunNpSin => fun("np_sin", None),
PrimDef::FunNpCos => fun("np_cos", None),
PrimDef::FunNpExp => fun("np_exp", None),
PrimDef::FunNpExp2 => fun("np_exp2", None),
PrimDef::FunNpLog => fun("np_log", None),
PrimDef::FunNpLog10 => fun("np_log10", None),
PrimDef::FunNpLog2 => fun("np_log2", None),
PrimDef::FunNpFabs => fun("np_fabs", None),
PrimDef::FunNpSqrt => fun("np_sqrt", None),
PrimDef::FunNpRint => fun("np_rint", None),
PrimDef::FunNpTan => fun("np_tan", None),
PrimDef::FunNpArcsin => fun("np_arcsin", None),
PrimDef::FunNpArccos => fun("np_arccos", None),
PrimDef::FunNpArctan => fun("np_arctan", None),
PrimDef::FunNpSinh => fun("np_sinh", None),
PrimDef::FunNpCosh => fun("np_cosh", None),
PrimDef::FunNpTanh => fun("np_tanh", None),
PrimDef::FunNpArcsinh => fun("np_arcsinh", None),
PrimDef::FunNpArccosh => fun("np_arccosh", None),
PrimDef::FunNpArctanh => fun("np_arctanh", None),
PrimDef::FunNpExpm1 => fun("np_expm1", None),
PrimDef::FunNpCbrt => fun("np_cbrt", None),
PrimDef::FunSpSpecErf => fun("sp_spec_erf", None),
PrimDef::FunSpSpecErfc => fun("sp_spec_erfc", None),
PrimDef::FunSpSpecGamma => fun("sp_spec_gamma", None),
PrimDef::FunSpSpecGammaln => fun("sp_spec_gammaln", None),
PrimDef::FunSpSpecJ0 => fun("sp_spec_j0", None),
PrimDef::FunSpSpecJ1 => fun("sp_spec_j1", None),
PrimDef::FunNpArctan2 => fun("np_arctan2", None),
PrimDef::FunNpCopysign => fun("np_copysign", None),
PrimDef::FunNpFmax => fun("np_fmax", None),
PrimDef::FunNpFmin => fun("np_fmin", None),
PrimDef::FunNpLdExp => fun("np_ldexp", None),
PrimDef::FunNpHypot => fun("np_hypot", None),
PrimDef::FunNpNextAfter => fun("np_nextafter", None),
derppening marked this conversation as resolved Outdated

allowlist can just be a &[PrimDef].

`allowlist` can just be a `&[PrimDef]`.
PrimDef::FunSome => fun("Some", None),
}
}
}
/// The [definition IDs][DefinitionId] for primitive types.
pub const PRIMITIVE_DEF_IDS: PrimitiveDefinitionIds = PrimitiveDefinitionIds {
int32: DefinitionId(0),
int64: DefinitionId(1),
uint32: DefinitionId(8),
uint64: DefinitionId(9),
float: DefinitionId(2),
bool: DefinitionId(3),
none: DefinitionId(4),
range: DefinitionId(5),
str: DefinitionId(6),
exception: DefinitionId(7),
option: DefinitionId(10),
ndarray: DefinitionId(14),
};
/// Asserts that a [`PrimDef`] is in an allowlist.
///
/// Like `debug_assert!`, this statements of this function are only
/// enabled if `cfg!(debug_assertions)` is true.
pub fn debug_assert_prim_is_allowed(prim: PrimDef, allowlist: &[PrimDef]) {
if cfg!(debug_assertions) {
let allowed = allowlist.iter().any(|p| *p == prim);
assert!(
allowed,
"Disallowed primitive definition. Got {prim:?}, but expects it to be in {allowlist:?}"
);
}
}
impl TopLevelDef {
pub fn to_string(&self, unifier: &mut Unifier) -> String {
@ -116,42 +311,42 @@ impl TopLevelComposer {
pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int32,
obj_id: PrimDef::Int32.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int64,
obj_id: PrimDef::Int64.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.float,
obj_id: PrimDef::Float.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.bool,
obj_id: PrimDef::Bool.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.none,
obj_id: PrimDef::None.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.range,
obj_id: PrimDef::Range.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.str,
obj_id: PrimDef::Str.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.exception,
obj_id: PrimDef::Exception.id(),
fields: vec![
("__name__".into(), (int32, true)),
("__file__".into(), (str, true)),
@ -168,12 +363,12 @@ impl TopLevelComposer {
params: VarMap::new(),
});
let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint32,
obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint64,
obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
@ -190,7 +385,7 @@ impl TopLevelComposer {
vars: VarMap::from([(option_type_var.1, option_type_var.0)]),
}));
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.option,
obj_id: PrimDef::Option.id(),
fields: vec![
("is_some".into(), (is_some_type_fun_ty, true)),
("is_none".into(), (is_some_type_fun_ty, true)),
@ -208,7 +403,8 @@ impl TopLevelComposer {
};
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_copy_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_copy_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![],
@ -219,13 +415,11 @@ impl TopLevelComposer {
]),
}));
let ndarray_fill_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "value".into(),
ty: ndarray_dtype_tvar.0,
default_value: None,
},
],
args: vec![FuncArg {
name: "value".into(),
ty: ndarray_dtype_tvar.0,
default_value: None,
}],
ret: none,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
@ -233,7 +427,7 @@ impl TopLevelComposer {
]),
}));
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
obj_id: PrimDef::NDArray.id(),
fields: Mapping::from([
("copy".into(), (ndarray_copy_fun_ty, true)),
("fill".into(), (ndarray_fill_fun_ty, true)),
@ -393,9 +587,7 @@ impl TopLevelComposer {
if let TypeEnum::TVar { id, .. } = unifier.get_ty(var_ty).as_ref() {
Ok(*id)
} else {
Err(HashSet::from([
"not type var".to_string(),
]))
Err(HashSet::from(["not type var".to_string()]))
}
}
@ -412,25 +604,27 @@ impl TopLevelComposer {
let (
TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }),
TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }),
) = (this, other) else {
) = (this, other)
else {
unreachable!("this function must be called with function type")
};
// check args
let args_ok = this_args
.iter()
.map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap()))
.zip(other_args.iter().map(|FuncArg { name, ty, .. }| {
(name, type_var_to_concrete_def.get(ty).unwrap())
}))
.all(|(this, other)| {
if this.0 == &"self".into() && this.0 == other.0 {
true
} else {
this.0 == other.0
&& check_overload_type_annotation_compatible(this.1, other.1, unifier)
}
});
let args_ok =
this_args
.iter()
.map(|FuncArg { name, ty, .. }| (name, type_var_to_concrete_def.get(ty).unwrap()))
.zip(other_args.iter().map(|FuncArg { name, ty, .. }| {
(name, type_var_to_concrete_def.get(ty).unwrap())
}))
.all(|(this, other)| {
if this.0 == &"self".into() && this.0 == other.0 {
true
} else {
this.0 == other.0
&& check_overload_type_annotation_compatible(this.1, other.1, unifier)
}
});
// check rets
let ret_ok = check_overload_type_annotation_compatible(
@ -473,12 +667,10 @@ impl TopLevelComposer {
}
} =>
{
return Err(HashSet::from([
format!(
"redundant type annotation for class fields at {}",
s.location
),
]))
return Err(HashSet::from([format!(
"redundant type annotation for class fields at {}",
s.location
)]))
}
ast::StmtKind::Assign { targets, .. } => {
for t in targets {
@ -602,112 +794,109 @@ pub fn parse_parameter_default_value(
Constant::Tuple(tuple) => Ok(SymbolValue::Tuple(
tuple.iter().map(|x| handle_constant(x, loc)).collect::<Result<Vec<_>, _>>()?,
)),
Constant::None => Err(HashSet::from([
format!(
"`None` is not supported, use `none` for option type instead ({loc})"
),
])),
Constant::None => Err(HashSet::from([format!(
"`None` is not supported, use `none` for option type instead ({loc})"
)])),
_ => unimplemented!("this constant is not supported at {}", loc),
}
}
match &default.node {
ast::ExprKind::Constant { value, .. } => handle_constant(value, &default.location),
ast::ExprKind::Call { func, args, .. } if args.len() == 1 => {
match &func.node {
ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<i64, _> = (*v).try_into();
match v {
Ok(v) => Ok(SymbolValue::I64(v)),
_ => Err(HashSet::from([
format!("default param value out of range at {}", default.location)
])),
}
ast::ExprKind::Call { func, args, .. } if args.len() == 1 => match &func.node {
ast::ExprKind::Name { id, .. } if *id == "int64".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<i64, _> = (*v).try_into();
match v {
Ok(v) => Ok(SymbolValue::I64(v)),
_ => Err(HashSet::from([format!(
"default param value out of range at {}",
default.location
)])),
}
_ => Err(HashSet::from([
format!("only allow constant integer here at {}", default.location),
]))
}
ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<u32, _> = (*v).try_into();
match v {
Ok(v) => Ok(SymbolValue::U32(v)),
_ => Err(HashSet::from([
format!("default param value out of range at {}", default.location),
])),
}
_ => Err(HashSet::from([format!(
"only allow constant integer here at {}",
default.location
)])),
},
ast::ExprKind::Name { id, .. } if *id == "uint32".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<u32, _> = (*v).try_into();
match v {
Ok(v) => Ok(SymbolValue::U32(v)),
_ => Err(HashSet::from([format!(
"default param value out of range at {}",
default.location
)])),
}
_ => Err(HashSet::from([
format!("only allow constant integer here at {}", default.location),
]))
}
ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<u64, _> = (*v).try_into();
match v {
Ok(v) => Ok(SymbolValue::U64(v)),
_ => Err(HashSet::from([
format!("default param value out of range at {}", default.location),
])),
}
_ => Err(HashSet::from([format!(
"only allow constant integer here at {}",
default.location
)])),
},
ast::ExprKind::Name { id, .. } if *id == "uint64".into() => match &args[0].node {
ast::ExprKind::Constant { value: Constant::Int(v), .. } => {
let v: Result<u64, _> = (*v).try_into();
match v {
Ok(v) => Ok(SymbolValue::U64(v)),
_ => Err(HashSet::from([format!(
"default param value out of range at {}",
default.location
)])),
}
_ => Err(HashSet::from([
format!("only allow constant integer here at {}", default.location),
]))
}
ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(
SymbolValue::OptionSome(
Box::new(parse_parameter_default_value(&args[0], resolver)?)
)
),
_ => Err(HashSet::from([
format!("unsupported default parameter at {}", default.location),
])),
}
}
ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts
.iter()
.map(|x| parse_parameter_default_value(x, resolver))
.collect::<Result<Vec<_>, _>>()?
_ => Err(HashSet::from([format!(
"only allow constant integer here at {}",
default.location
)])),
},
ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok(SymbolValue::OptionSome(
Box::new(parse_parameter_default_value(&args[0], resolver)?),
)),
_ => Err(HashSet::from([format!(
"unsupported default parameter at {}",
default.location
)])),
},
ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(
elts.iter()
.map(|x| parse_parameter_default_value(x, resolver))
.collect::<Result<Vec<_>, _>>()?,
)),
ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone),
ast::ExprKind::Name { id, .. } => {
resolver.get_default_param_value(default).ok_or_else(
|| HashSet::from([
format!(
"`{}` cannot be used as a default parameter at {} \
resolver.get_default_param_value(default).ok_or_else(|| {
HashSet::from([format!(
"`{}` cannot be used as a default parameter at {} \
(not primitive type, option or tuple / not defined?)",
id,
default.location
),
])
)
id, default.location
)])
})
}
_ => Err(HashSet::from([
format!(
"unsupported default parameter (not primitive type, option or tuple) at {}",
default.location
),
]))
_ => Err(HashSet::from([format!(
"unsupported default parameter (not primitive type, option or tuple) at {}",
default.location
)])),
}
}
/// Obtains the element type of an array-like type.
pub fn arraylike_flatten_element_type(unifier: &mut Unifier, ty: Type) -> Type {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray =>
unpack_ndarray_var_tys(unifier, ty).0,
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
unpack_ndarray_var_tys(unifier, ty).0
}
TypeEnum::TList { ty } => arraylike_flatten_element_type(unifier, *ty),
_ => ty
_ => ty,
}
}
/// Obtains the number of dimensions of an array-like type.
pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
match &*unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let ndims = unpack_ndarray_var_tys(unifier, ty).1;
let TypeEnum::TLiteral { values, .. } = &*unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", unifier.stringify(ndims))
@ -721,6 +910,6 @@ pub fn arraylike_get_ndims(unifier: &mut Unifier, ty: Type) -> u64 {
}
TypeEnum::TList { ty } => arraylike_get_ndims(unifier, *ty) + 1,
_ => 0
_ => 0,
}
}

View File

@ -8,7 +8,9 @@ use std::{
use super::codegen::CodeGenContext;
use super::typecheck::type_inferencer::PrimitiveStore;
use super::typecheck::typedef::{FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap};
use super::typecheck::typedef::{
FunSignature, FuncArg, SharedUnifier, Type, TypeEnum, Unifier, VarMap,
};
use crate::{
codegen::CodeGenerator,
symbol_resolver::{SymbolResolver, ValueEnum},
@ -32,16 +34,15 @@ use type_annotation::*;
#[cfg(test)]
mod test;
type GenCallCallback =
dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>,
Option<(Type, ValueEnum<'ctx>)>,
(&FunSignature, DefinitionId),
Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
&mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String>
+ Send
+ Sync;
type GenCallCallback = dyn for<'ctx, 'a> Fn(
&mut CodeGenContext<'ctx, 'a>,
Option<(Type, ValueEnum<'ctx>)>,
(&FunSignature, DefinitionId),
Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
&mut dyn CodeGenerator,
) -> Result<Option<BasicValueEnum<'ctx>>, String>
+ Send
+ Sync;
pub struct GenCall {
fp: Box<GenCallCallback>,
@ -53,7 +54,7 @@ impl GenCall {
GenCall { fp }
}
/// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
/// Creates a dummy instance of [`GenCall`], which invokes [`unreachable!()`] with the given
/// `reason`.
#[must_use]
pub fn create_dummy(reason: String) -> GenCall {

View File

@ -1,14 +1,14 @@
use itertools::Itertools;
use crate::{
toplevel::helper::PRIMITIVE_DEF_IDS,
toplevel::helper::PrimDef,
typecheck::{
type_inferencer::PrimitiveStore,
typedef::{Type, TypeEnum, Unifier, VarMap},
},
};
use itertools::Itertools;
/// Creates a `ndarray` [`Type`] with the given type arguments.
///
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
@ -37,15 +37,13 @@ pub fn subst_ndarray_tvars(
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
if dtype.is_none() && ndims.is_none() {
return ndarray
return ndarray;
}
let tvar_ids = params.iter()
.map(|(obj_id, _)| *obj_id)
.collect_vec();
let tvar_ids = params.iter().map(|(obj_id, _)| *obj_id).collect_vec();
debug_assert_eq!(tvar_ids.len(), 2);
let mut tvar_subst = VarMap::new();
@ -59,45 +57,29 @@ pub fn subst_ndarray_tvars(
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
fn unpack_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
) -> Vec<(u32, Type)> {
fn unpack_ndarray_tvars(unifier: &mut Unifier, ndarray: Type) -> Vec<(u32, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
debug_assert_eq!(*obj_id, PrimDef::NDArray.id());
debug_assert_eq!(params.len(), 2);
params.iter()
params
.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(
unifier: &mut Unifier,
ndarray: Type,
) -> (u32, u32) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.0)
.collect_tuple()
.unwrap()
pub fn unpack_ndarray_var_ids(unifier: &mut Unifier, ndarray: Type) -> (u32, u32) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.0).collect_tuple().unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(
unifier: &mut Unifier,
ndarray: Type,
) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.1)
.collect_tuple()
.unwrap()
pub fn unpack_ndarray_var_tys(unifier: &mut Unifier, ndarray: Type) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray).into_iter().map(|v| v.1).collect_tuple().unwrap()
}

View File

@ -65,7 +65,11 @@ impl SymbolResolver for Resolver {
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0.id_to_def.lock().get(&id).cloned()
self.0
.id_to_def
.lock()
.get(&id)
.cloned()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
}

View File

@ -1,7 +1,7 @@
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::typecheck::typedef::VarMap;
use super::*;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PrimDef;
use crate::typecheck::typedef::VarMap;
use nac3parser::ast::Constant;
#[derive(Clone, Debug)]
@ -29,9 +29,7 @@ impl TypeAnnotation {
Primitive(ty) | TypeVar(ty) => unifier.stringify(*ty),
CustomClass { id, params } => {
let class_name = if let Some(ref top) = unifier.top_level {
if let TopLevelDef::Class { name, .. } =
&*top.definitions.read()[id.0].read()
{
if let TopLevelDef::Class { name, .. } = &*top.definitions.read()[id.0].read() {
(*name).into()
} else {
unreachable!()
@ -39,24 +37,26 @@ impl TypeAnnotation {
} else {
format!("class_def_{}", id.0)
};
format!(
"{}{}",
class_name,
{
let param_list = params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
if param_list.is_empty() {
String::new()
} else {
format!("[{param_list}]")
}
format!("{}{}", class_name, {
let param_list =
params.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ");
if param_list.is_empty() {
String::new()
} else {
format!("[{param_list}]")
}
)
})
}
Literal(values) => {
format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", "))
}
Literal(values) => format!("Literal({})", values.iter().map(|v| format!("{v:?}")).join(", ")),
Virtual(ty) => format!("virtual[{}]", ty.stringify(unifier)),
List(ty) => format!("list[{}]", ty.stringify(unifier)),
Tuple(types) => {
format!("tuple[{}]", types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", "))
format!(
"tuple[{}]",
types.iter().map(|p| p.stringify(unifier)).collect_vec().join(", ")
)
}
}
}
@ -95,7 +95,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
} else if id == &"str".into() {
Ok(TypeAnnotation::Primitive(primitives.str))
} else if id == &"Exception".into() {
Ok(TypeAnnotation::CustomClass { id: PRIMITIVE_DEF_IDS.exception, params: Vec::default() })
Ok(TypeAnnotation::CustomClass { id: PrimDef::Exception.id(), params: Vec::default() })
} else if let Ok(obj_id) = resolver.get_identifier_def(*id) {
let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read();
@ -103,12 +103,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone()
} else {
return Err(HashSet::from([
format!(
"function cannot be used as a type (at {})",
expr.location
),
]))
return Err(HashSet::from([format!(
"function cannot be used as a type (at {})",
expr.location
)]));
}
} else {
locked.get(&obj_id).unwrap().clone()
@ -116,13 +114,11 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
};
// check param number here
if !type_vars.is_empty() {
return Err(HashSet::from([
format!(
"expect {} type variable parameter but got 0 (at {})",
type_vars.len(),
expr.location,
),
]))
return Err(HashSet::from([format!(
"expect {} type variable parameter but got 0 (at {})",
type_vars.len(),
expr.location,
)]));
}
Ok(TypeAnnotation::CustomClass { id: obj_id, params: vec![] })
} else if let Ok(ty) = resolver.get_symbol_type(unifier, top_level_defs, primitives, *id) {
@ -131,14 +127,16 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
unifier.unify(var, ty).unwrap();
Ok(TypeAnnotation::TypeVar(ty))
} else {
Err(HashSet::from([
format!("`{}` is not a valid type annotation (at {})", id, expr.location),
]))
Err(HashSet::from([format!(
"`{}` is not a valid type annotation (at {})",
id, expr.location
)]))
}
} else {
Err(HashSet::from([
format!("`{}` is not a valid type annotation (at {})", id, expr.location),
]))
Err(HashSet::from([format!(
"`{}` is not a valid type annotation (at {})",
id, expr.location
)]))
}
};
@ -147,11 +145,13 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
slice: &ast::Expr<T>,
unifier: &mut Unifier,
mut locked: HashMap<DefinitionId, Vec<Type>>| {
if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()].contains(id)
if ["virtual".into(), "Generic".into(), "list".into(), "tuple".into(), "Option".into()]
.contains(id)
{
return Err(HashSet::from([
format!("keywords cannot be class name (at {})", expr.location),
]))
return Err(HashSet::from([format!(
"keywords cannot be class name (at {})",
expr.location
)]));
}
let obj_id = resolver.get_identifier_def(*id)?;
let type_vars = {
@ -174,14 +174,12 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
vec![slice]
};
if type_vars.len() != params_ast.len() {
return Err(HashSet::from([
format!(
"expect {} type parameters but got {} (at {})",
type_vars.len(),
params_ast.len(),
params_ast[0].location,
),
]))
return Err(HashSet::from([format!(
"expect {} type parameters but got {} (at {})",
type_vars.len(),
params_ast.len(),
params_ast[0].location,
)]));
}
let result = params_ast
.iter()
@ -210,7 +208,7 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
"application of type vars to generic class is not currently supported (at {})",
params_ast[0].location
),
]))
]));
}
};
Ok(TypeAnnotation::CustomClass { id: obj_id, params: param_type_infos })
@ -309,9 +307,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
// Literal
ast::ExprKind::Subscript { value, slice, .. }
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
} => {
if {
matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Literal".into())
} =>
{
let tup_elts = {
if let ast::ExprKind::Tuple { elts, .. } = &slice.node {
elts.as_slice()
@ -321,20 +320,18 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
};
let type_annotations = tup_elts
.iter()
.map(|e| {
match &e.node {
ast::ExprKind::Constant { value, .. } => Ok(
TypeAnnotation::Literal(vec![value.clone()]),
),
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
),
.map(|e| match &e.node {
ast::ExprKind::Constant { value, .. } => {
Ok(TypeAnnotation::Literal(vec![value.clone()]))
}
_ => parse_ast_to_type_annotation_kinds(
resolver,
top_level_defs,
unifier,
primitives,
e,
locked.clone(),
),
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
@ -347,9 +344,10 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if type_annotations.len() == 1 {
Ok(TypeAnnotation::Literal(type_annotations))
} else {
Err(HashSet::from([
format!("multiple literal bounds are currently unsupported (at {})", value.location)
]))
Err(HashSet::from([format!(
"multiple literal bounds are currently unsupported (at {})",
value.location
)]))
}
}
@ -358,19 +356,19 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
if let ast::ExprKind::Name { id, .. } = &value.node {
class_name_handle(id, slice, unifier, locked)
} else {
Err(HashSet::from([
format!("unsupported expression type for class name (at {})", value.location)
]))
Err(HashSet::from([format!(
"unsupported expression type for class name (at {})",
value.location
)]))
}
}
ast::ExprKind::Constant { value, .. } => {
Ok(TypeAnnotation::Literal(vec![value.clone()]))
}
ast::ExprKind::Constant { value, .. } => Ok(TypeAnnotation::Literal(vec![value.clone()])),
_ => Err(HashSet::from([
format!("unsupported expression for type annotation (at {})", expr.location),
])),
_ => Err(HashSet::from([format!(
"unsupported expression for type annotation (at {})",
expr.location
)])),
}
}
@ -381,7 +379,7 @@ pub fn get_type_from_type_annotation_kinds(
top_level_defs: &[Arc<RwLock<TopLevelDef>>],
unifier: &mut Unifier,
ann: &TypeAnnotation,
subst_list: &mut Option<Vec<Type>>
subst_list: &mut Option<Vec<Type>>,
) -> Result<Type, HashSet<String>> {
match ann {
TypeAnnotation::CustomClass { id: obj_id, params } => {
@ -392,24 +390,17 @@ pub fn get_type_from_type_annotation_kinds(
};
if type_vars.len() != params.len() {
return Err(HashSet::from([
format!(
"unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
params.len()
),
]))
return Err(HashSet::from([format!(
"unexpected number of type parameters: expected {} but got {}",
type_vars.len(),
params.len()
)]));
}
let param_ty = params
.iter()
.map(|x| {
get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
x,
subst_list
)
get_type_from_type_annotation_kinds(top_level_defs, unifier, x, subst_list)
})
.collect::<Result<Vec<_>, _>>()?;
@ -419,7 +410,14 @@ pub fn get_type_from_type_annotation_kinds(
let mut result = VarMap::new();
for (tvar, p) in type_vars.iter().zip(param_ty) {
match unifier.get_ty(*tvar).as_ref() {
TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false } => {
TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
} => {
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
@ -434,18 +432,16 @@ pub fn get_type_from_type_annotation_kinds(
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([
format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)
]))
return Err(HashSet::from([format!(
"cannot apply type {} to type variable with id {:?}",
unifier.internal_stringify(
p,
&mut |id| format!("class{id}"),
&mut |id| format!("typevar{id}"),
&mut None
),
*id
)]));
}
}
@ -454,24 +450,18 @@ pub fn get_type_from_type_annotation_kinds(
let ok: bool = {
// create a temp type var and unify to check compatibility
p == *tvar || {
let temp = unifier.get_fresh_const_generic_var(
ty,
*name,
*loc,
);
let temp = unifier.get_fresh_const_generic_var(ty, *name, *loc);
unifier.unify(temp.0, p).is_ok()
}
};
if ok {
result.insert(*id, p);
} else {
return Err(HashSet::from([
format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
),
]))
return Err(HashSet::from([format!(
"cannot apply type {} to type variable {}",
unifier.stringify(p),
name.unwrap_or_else(|| format!("typevar{id}").into()),
)]));
}
}
@ -507,7 +497,8 @@ pub fn get_type_from_type_annotation_kinds(
}
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Literal(values) => {
let values = values.iter()
let values = values
.iter()
.map(SymbolValue::from_constant_inferred)
.collect::<Result<Vec<_>, _>>()
.map_err(|err| HashSet::from([err]))?;
@ -520,7 +511,7 @@ pub fn get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
ty.as_ref(),
subst_list
subst_list,
)?;
Ok(unifier.add_ty(TypeEnum::TVirtual { ty }))
}
@ -529,7 +520,7 @@ pub fn get_type_from_type_annotation_kinds(
top_level_defs,
unifier,
ty.as_ref(),
subst_list
subst_list,
)?;
Ok(unifier.add_ty(TypeEnum::TList { ty }))
}
@ -607,7 +598,8 @@ pub fn check_overload_type_annotation_compatible(
let (
TypeEnum::TVar { id: a, fields: None, .. },
TypeEnum::TVar { id: b, fields: None, .. },
) = (a, b) else {
) = (a, b)
else {
unreachable!("must be type var")
};

View File

@ -2,15 +2,17 @@ use crate::typecheck::typedef::TypeEnum;
use super::type_inferencer::Inferencer;
use super::typedef::Type;
use nac3parser::ast::{self, Constant, Expr, ExprKind, Operator::{LShift, RShift}, Stmt, StmtKind, StrRef};
use nac3parser::ast::{
self, Constant, Expr, ExprKind,
Operator::{LShift, RShift},
Stmt, StmtKind, StrRef,
};
use std::{collections::HashSet, iter::once};
impl<'a> Inferencer<'a> {
fn should_have_value(&mut self, expr: &Expr<Option<Type>>) -> Result<(), HashSet<String>> {
if matches!(expr.custom, Some(ty) if self.unifier.unioned(ty, self.primitives.none)) {
Err(HashSet::from([
format!("Error at {}: cannot have value none", expr.location),
]))
Err(HashSet::from([format!("Error at {}: cannot have value none", expr.location)]))
} else {
Ok(())
}
@ -22,9 +24,9 @@ impl<'a> Inferencer<'a> {
defined_identifiers: &mut HashSet<StrRef>,
) -> Result<(), HashSet<String>> {
match &pattern.node {
ExprKind::Name { id, .. } if id == &"none".into() => Err(HashSet::from([
format!("cannot assign to a `none` (at {})", pattern.location),
])),
ExprKind::Name { id, .. } if id == &"none".into() => {
Err(HashSet::from([format!("cannot assign to a `none` (at {})", pattern.location)]))
}
ExprKind::Name { id, .. } => {
if !defined_identifiers.contains(id) {
defined_identifiers.insert(*id);
@ -44,20 +46,17 @@ impl<'a> Inferencer<'a> {
self.should_have_value(value)?;
self.check_expr(slice, defined_identifiers)?;
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return Err(HashSet::from([
format!(
"Error at {}: cannot assign to tuple element",
value.location
),
]))
return Err(HashSet::from([format!(
"Error at {}: cannot assign to tuple element",
value.location
)]));
}
Ok(())
}
ExprKind::Constant { .. } => {
Err(HashSet::from([
format!("cannot assign to a constant (at {})", pattern.location),
]))
}
ExprKind::Constant { .. } => Err(HashSet::from([format!(
"cannot assign to a constant (at {})",
pattern.location
)])),
_ => self.check_expr(pattern, defined_identifiers),
}
}
@ -69,14 +68,14 @@ impl<'a> Inferencer<'a> {
) -> Result<(), HashSet<String>> {
// there are some cases where the custom field is None
if let Some(ty) = &expr.custom {
if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) && !self.unifier.is_concrete(*ty, &self.function_data.bound_variables) {
return Err(HashSet::from([
format!(
"expected concrete type at {} but got {}",
expr.location,
self.unifier.get_ty(*ty).get_type_name()
)
]))
if !matches!(&expr.node, ExprKind::Constant { value: Constant::Ellipsis, .. })
&& !self.unifier.is_concrete(*ty, &self.function_data.bound_variables)
{
return Err(HashSet::from([format!(
"expected concrete type at {} but got {}",
expr.location,
self.unifier.get_ty(*ty).get_type_name()
)]));
}
}
match &expr.node {
@ -96,12 +95,10 @@ impl<'a> Inferencer<'a> {
self.defined_identifiers.insert(*id);
}
Err(e) => {
return Err(HashSet::from([
format!(
"type error at identifier `{}` ({}) at {}",
id, e, expr.location
)
]))
return Err(HashSet::from([format!(
"type error at identifier `{}` ({}) at {}",
id, e, expr.location
)]))
}
}
}
@ -127,17 +124,13 @@ impl<'a> Inferencer<'a> {
// Check whether a bitwise shift has a negative RHS constant value
if *op == LShift || *op == RShift {
if let ExprKind::Constant { value, .. } = &right.node {
let Constant::Int(rhs_val) = value else {
unreachable!()
};
let Constant::Int(rhs_val) = value else { unreachable!() };
if *rhs_val < 0 {
return Err(HashSet::from([
format!(
"shift count is negative at {}",
right.location
),
]))
return Err(HashSet::from([format!(
"shift count is negative at {}",
right.location
)]));
}
}
}
@ -214,16 +207,16 @@ impl<'a> Inferencer<'a> {
/// is freed when the function returns.
fn check_return_value_ty(&mut self, ret_ty: Type) -> bool {
match &*self.unifier.get_ty_immutable(ret_ty) {
TypeEnum::TObj { .. } => {
[
self.primitives.int32,
self.primitives.int64,
self.primitives.uint32,
self.primitives.uint64,
self.primitives.float,
self.primitives.bool,
].iter().any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty))
}
TypeEnum::TObj { .. } => [
self.primitives.int32,
self.primitives.int64,
self.primitives.uint32,
self.primitives.uint64,
self.primitives.float,
self.primitives.bool,
]
.iter()
.any(|allowed_ty| self.unifier.unioned(ret_ty, *allowed_ty)),
TypeEnum::TTuple { ty } => ty.iter().all(|t| self.check_return_value_ty(*t)),
_ => false,
}
@ -330,8 +323,11 @@ impl<'a> Inferencer<'a> {
if let Some(ret_ty) = value.custom {
// Explicitly allow ellipsis as a return value, as the type of the ellipsis is contextually
// inferred and just generates an unconditional assertion
if matches!(value.node, ExprKind::Constant { value: Constant::Ellipsis, .. }) {
return Ok(true)
if matches!(
value.node,
ExprKind::Constant { value: Constant::Ellipsis, .. }
) {
return Ok(true);
}
if !self.check_return_value_ty(ret_ty) {
@ -341,7 +337,7 @@ impl<'a> Inferencer<'a> {
self.unifier.stringify(ret_ty),
value.location,
),
]))
]));
}
}
}

View File

@ -1,16 +1,17 @@
use std::cmp::max;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::toplevel::helper::PrimDef;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
};
use itertools::Itertools;
use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::cmp::max;
use std::collections::HashMap;
use std::rc::Rc;
use itertools::Itertools;
use strum::IntoEnumIterator;
#[must_use]
pub fn binop_name(op: &Operator) -> &'static str {
@ -255,7 +256,14 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
/// `LShift`, `RShift`
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[store.int32, store.uint32], Some(ty), &[Operator::LShift, Operator::RShift]);
impl_binop(
unifier,
store,
ty,
&[store.int32, store.uint32],
Some(ty),
&[Operator::LShift, Operator::RShift],
);
}
/// `Div`
@ -297,7 +305,7 @@ pub fn impl_matmul(
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::MatMult]);
}
@ -353,8 +361,8 @@ pub fn typeof_ndarray_broadcast(
left: Type,
right: Type,
) -> Result<Type, String> {
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_left_ndarray = left.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = right.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
assert!(is_left_ndarray || is_right_ndarray);
@ -375,7 +383,8 @@ pub fn typeof_ndarray_broadcast(
_ => unreachable!(),
};
let res_ndims = left_ty_ndims.into_iter()
let res_ndims = left_ty_ndims
.into_iter()
.cartesian_product(right_ty_ndims)
.map(|(left, right)| {
let left_val = u64::try_from(left).unwrap();
@ -390,11 +399,7 @@ pub fn typeof_ndarray_broadcast(
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
} else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray {
(left, right)
} else {
(right, left)
};
let (ndarray_ty, scalar_ty) = if is_left_ndarray { (left, right) } else { (right, left) };
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
@ -424,21 +429,17 @@ pub fn typeof_binop(
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(match op {
Operator::Add
| Operator::Sub
| Operator::Mult
| Operator::Mod
| Operator::FloorDiv => {
Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None)
return Ok(None);
}
}
@ -464,12 +465,14 @@ pub fn typeof_binop(
(2, 2) => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
(lhs, rhs) if lhs == 0 || rhs == 0 => {
return Err(format!(
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
(rhs == 0) as u8
))
"Input operand {} does not have enough dimensions (has {lhs}, requires {rhs})",
(rhs == 0) as u8
))
}
(lhs, rhs) => {
return Err(format!("ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"))
return Err(format!(
"ndarray.__matmul__ on {lhs}D and {rhs}D operands not supported"
))
}
}
}
@ -480,29 +483,35 @@ pub fn typeof_binop(
} else if unifier.unioned(lhs, rhs) {
primitives.float
} else {
return Ok(None)
return Ok(None);
}
}
Operator::Pow => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) {
} else if [
primitives.int32,
primitives.int64,
primitives.uint32,
primitives.uint64,
primitives.float,
]
.into_iter()
.any(|ty| unifier.unioned(lhs, ty))
{
lhs
} else {
return Ok(None)
return Ok(None);
}
}
Operator::LShift
| Operator::RShift => lhs,
Operator::BitOr
| Operator::BitXor
| Operator::BitAnd => {
Operator::LShift | Operator::RShift => lhs,
Operator::BitOr | Operator::BitXor | Operator::BitAnd => {
if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None)
return Ok(None);
}
}
}))
@ -516,45 +525,46 @@ pub fn typeof_unaryop(
) -> Result<Option<Type>, String> {
let operand_obj_id = operand.obj_id(unifier);
if *op == Unaryop::Not && operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
return Err("The truth value of an array with more than one element is ambiguous".to_string())
if *op == Unaryop::Not
&& operand_obj_id.is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap())
{
return Err(
"The truth value of an array with more than one element is ambiguous".to_string()
);
}
Ok(match *op {
Unaryop::Not => {
match operand_obj_id {
Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand),
Some(_) => Some(primitives.bool),
_ => None
}
}
Unaryop::Not => match operand_obj_id {
Some(v) if v == PrimDef::NDArray.id() => Some(operand),
Some(_) => Some(primitives.bool),
_ => None,
},
Unaryop::Invert => {
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand)
} else {
None
}
}
Unaryop::UAdd
| Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
Unaryop::UAdd | Unaryop::USub => {
if operand_obj_id.is_some_and(|id| id == PrimDef::NDArray.id()) {
let (dtype, _) = unpack_ndarray_var_tys(unifier, operand);
if dtype.obj_id(unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
if dtype.obj_id(unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
return Err(if *op == Unaryop::UAdd {
"The ufunc 'positive' cannot be applied to ndarray[bool, N]".to_string()
} else {
"The numpy boolean negative, the `-` operator, is not supported, use the `~` operator function instead.".to_string()
})
});
}
Some(operand)
} else if operand_obj_id.is_some_and(|id| id == PRIMITIVE_DEF_IDS.bool) {
} else if operand_obj_id.is_some_and(|id| id == PrimDef::Bool.id()) {
Some(primitives.int32)
} else if operand_obj_id.is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
} else if operand_obj_id.is_some_and(|id| PrimDef::iter().any(|prim| id == prim.id())) {
Some(operand)
} else {
None
@ -571,12 +581,8 @@ pub fn typeof_cmpop(
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id());
Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
@ -586,7 +592,7 @@ pub fn typeof_cmpop(
} else if unifier.unioned(lhs, rhs) {
primitives.bool
} else {
return Ok(None)
return Ok(None);
}))
}
@ -643,11 +649,19 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* ndarray ===== */
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
let ndarray_usized_ndims_tvar =
unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t =
make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_basic_arithmetic(
unifier,
store,
ndarray_t,
&[ndarray_unsized_t, ndarray_unsized_dtype_t],
None,
);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);

View File

@ -89,10 +89,7 @@ impl<'a> Display for DisplayTypeError<'a> {
IncorrectArgType { name, expected, got } => {
let expected = self.unifier.stringify_with_notes(*expected, &mut notes);
let got = self.unifier.stringify_with_notes(*got, &mut notes);
write!(
f,
"Incorrect argument type for {name}. Expected {expected}, but got {got}"
)
write!(f, "Incorrect argument type for {name}. Expected {expected}, but got {got}")
}
FieldUnificationError { field, types, loc } => {
let lhs = self.unifier.stringify_with_notes(types.0, &mut notes);

View File

@ -1,21 +1,25 @@
use std::collections::{HashMap, HashSet};
use std::convert::{From, TryInto};
use std::iter::once;
use std::{cell::RefCell, sync::Arc};
use std::ops::Not;
use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PRIMITIVE_DEF_IDS},
helper::{arraylike_flatten_element_type, arraylike_get_ndims, PrimDef},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext,
},
};
use itertools::{Itertools, izip};
use nac3parser::ast::{self, fold::{self, Fold}, Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef};
use itertools::{izip, Itertools};
use nac3parser::ast::{
self,
fold::{self, Fold},
Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef,
};
#[cfg(test)]
mod test;
@ -187,9 +191,12 @@ impl<'a> Fold<()> for Inferencer<'a> {
}
if let Some(old_typ) = self.variable_mapping.insert(name, typ) {
let loc = handler.location;
self.unifier.unify(old_typ, typ).map_err(|e| HashSet::from([
e.at(Some(loc)).to_display(self.unifier).to_string(),
]))?;
self.unifier.unify(old_typ, typ).map_err(|e| {
HashSet::from([e
.at(Some(loc))
.to_display(self.unifier)
.to_string()])
})?;
}
}
let mut type_ = naive_folder.fold_expr(*type_)?;
@ -234,8 +241,12 @@ impl<'a> Fold<()> for Inferencer<'a> {
self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?;
} else {
let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => todo!(),
TypeEnum::TList { .. } => {
self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() })
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
todo!()
}
_ => unreachable!(),
};
self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?;
@ -273,13 +284,10 @@ impl<'a> Fold<()> for Inferencer<'a> {
let targets: Result<Vec<_>, _> = targets
.into_iter()
.map(|target| {
let ExprKind::Name { id, ctx } = target.node else {
unreachable!()
};
let ExprKind::Name { id, ctx } = target.node else { unreachable!() };
self.defined_identifiers.insert(id);
let target_ty = if let Some(ty) = self.variable_mapping.get(&id)
{
let target_ty = if let Some(ty) = self.variable_mapping.get(&id) {
*ty
} else {
let unifier: &mut Unifier = self.unifier;
@ -305,8 +313,9 @@ impl<'a> Fold<()> for Inferencer<'a> {
})
.collect();
let loc = node.location;
let targets = targets
.map_err(|e| HashSet::from([e.at(Some(loc)).to_display(self.unifier).to_string()]))?;
let targets = targets.map_err(|e| {
HashSet::from([e.at(Some(loc)).to_display(self.unifier).to_string()])
})?;
return Ok(Located {
location: node.location,
node: ast::StmtKind::Assign {
@ -463,7 +472,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
self.unify(test.custom.unwrap(), self.primitives.bool, &test.location)?;
match msg {
Some(m) => self.unify(m.custom.unwrap(), self.primitives.str, &m.location)?,
None => ()
None => (),
}
}
_ => return report_error("Unsupported statement type", stmt.location),
@ -485,9 +494,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
_ => fold::fold_expr(self, node)?,
};
let custom = match &expr.node {
ExprKind::Constant { value, .. } => {
Some(self.infer_constant(value, &expr.location)?)
}
ExprKind::Constant { value, .. } => Some(self.infer_constant(value, &expr.location)?),
ExprKind::Name { id, .. } => {
// the name `none` is special since it may have different types
if id == &"none".into() {
@ -497,7 +504,9 @@ impl<'a> Fold<()> for Inferencer<'a> {
let var_map = params
.iter()
.map(|(id_var, ty)| {
let TypeEnum::TVar { id, range, name, loc, .. } = &*self.unifier.get_ty(*ty) else {
let TypeEnum::TVar { id, range, name, loc, .. } =
&*self.unifier.get_ty(*ty)
else {
unreachable!()
};
@ -552,9 +561,9 @@ impl<'a> Fold<()> for Inferencer<'a> {
ExprKind::IfExp { test, body, orelse } => {
Some(self.infer_if_expr(test, body.as_ref(), orelse.as_ref())?)
}
ExprKind::ListComp { .. }
| ExprKind::Lambda { .. }
| ExprKind::Call { .. } => expr.custom, // already computed
ExprKind::ListComp { .. } | ExprKind::Lambda { .. } | ExprKind::Call { .. } => {
expr.custom
} // already computed
ExprKind::Slice { .. } => {
// slices aren't exactly ranges, but for our purposes this should suffice
Some(self.primitives.range)
@ -575,11 +584,9 @@ impl<'a> Inferencer<'a> {
}
fn unify(&mut self, a: Type, b: Type, location: &Location) -> Result<(), HashSet<String>> {
self.unifier
.unify(a, b)
.map_err(|e| HashSet::from([
e.at(Some(*location)).to_display(self.unifier).to_string(),
]))
self.unifier.unify(a, b).map_err(|e| {
HashSet::from([e.at(Some(*location)).to_display(self.unifier).to_string()])
})
}
fn infer_pattern(&mut self, pattern: &ast::Expr<()>) -> Result<(), HashSet<String>> {
@ -622,12 +629,15 @@ impl<'a> Inferencer<'a> {
loc: Some(location),
};
if let Some(ret) = ret {
self.unifier.unify(sign.ret, ret)
self.unifier
.unify(sign.ret, ret)
.map_err(|err| {
format!("Cannot unify {} <: {} - {:?}",
self.unifier.stringify(sign.ret),
self.unifier.stringify(ret),
TypeError::new(err.kind, Some(location)))
format!(
"Cannot unify {} <: {} - {:?}",
self.unifier.stringify(sign.ret),
self.unifier.stringify(ret),
TypeError::new(err.kind, Some(location))
)
})
.unwrap();
}
@ -638,9 +648,12 @@ impl<'a> Inferencer<'a> {
.map(|v| v.name)
.rev()
.collect();
self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| HashSet::from([
e.at(Some(location)).to_display(self.unifier).to_string(),
]))?;
self.unifier.unify_call(&call, ty, sign, &required).map_err(|e| {
HashSet::from([e
.at(Some(location))
.to_display(self.unifier)
.to_string()])
})?;
return Ok(sign.ret);
}
}
@ -815,7 +828,7 @@ impl<'a> Inferencer<'a> {
keywords: &[Located<ast::KeywordData>],
) -> Result<Option<ast::Expr<Option<Type>>>, HashSet<String>> {
let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else {
return Ok(None)
return Ok(None);
};
// handle special functions that cannot be typed in the usual way...
@ -824,7 +837,7 @@ impl<'a> Inferencer<'a> {
return report_error(
"`virtual` can only accept 1/2 positional arguments",
*func_location,
)
);
}
let arg0 = self.fold_expr(args.remove(0))?;
let ty = if let Some(arg) = args.pop() {
@ -852,19 +865,19 @@ impl<'a> Inferencer<'a> {
args: vec![arg0],
keywords: vec![],
},
}))
}));
}
if [
"int32",
"float",
"bool",
"round",
"round64",
"np_isnan",
"np_isinf",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let target_ty = if id == &"int32".into() || id == &"round".into() || id == &"floor".into() || id == &"ceil".into() {
if ["int32", "float", "bool", "round", "round64", "np_isnan", "np_isinf"]
.iter()
.any(|fun_id| id == &(*fun_id).into())
&& args.len() == 1
{
let target_ty = if id == &"int32".into()
|| id == &"round".into()
|| id == &"floor".into()
|| id == &"ceil".into()
{
self.primitives.int32
} else if id == &"round64".into() || id == &"floor64".into() || id == &"ceil64".into() {
self.primitives.int64
@ -872,12 +885,15 @@ impl<'a> Inferencer<'a> {
self.primitives.float
} else if id == &"bool".into() || id == &"np_isnan".into() || id == &"np_isinf".into() {
self.primitives.bool
} else { unreachable!() };
} else {
unreachable!()
};
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
@ -886,13 +902,11 @@ impl<'a> Inferencer<'a> {
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "n".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
args: vec![FuncArg {
name: "n".into(),
ty: arg0.custom.unwrap(),
default_value: None,
}],
ret,
vars: VarMap::new(),
}));
@ -909,17 +923,15 @@ impl<'a> Inferencer<'a> {
args: vec![arg0],
keywords: vec![],
},
}))
}));
}
if [
"np_min",
"np_max",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
ndarray_dtype
@ -928,13 +940,11 @@ impl<'a> Inferencer<'a> {
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "a".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
args: vec![FuncArg {
name: "a".into(),
ty: arg0.custom.unwrap(),
default_value: None,
}],
ret,
vars: VarMap::new(),
}));
@ -951,7 +961,7 @@ impl<'a> Inferencer<'a> {
args: vec![arg0],
keywords: vec![],
},
}))
}));
}
if [
@ -964,29 +974,32 @@ impl<'a> Inferencer<'a> {
"np_ldexp",
"np_hypot",
"np_nextafter",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 2 {
]
.iter()
.any(|fun_id| id == &(*fun_id).into())
&& args.len() == 2
{
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let arg1 = self.fold_expr(args.remove(0))?;
let arg1_ty = arg1.custom.unwrap();
let arg0_dtype = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
unpack_ndarray_var_tys(self.unifier, arg0_ty).0
} else {
arg0_ty
};
let arg0_dtype =
if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg0_ty).0
} else {
arg0_ty
};
let arg1_dtype = if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
unpack_ndarray_var_tys(self.unifier, arg1_ty).0
} else {
arg1_ty
};
let arg1_dtype =
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
unpack_ndarray_var_tys(self.unifier, arg1_ty).0
} else {
arg1_ty
};
let expected_arg1_dtype = if id == &"np_ldexp".into() {
self.primitives.int32
} else {
arg0_dtype
};
let expected_arg1_dtype =
if id == &"np_ldexp".into() { self.primitives.int32 } else { arg0_dtype };
if !self.unifier.unioned(arg1_dtype, expected_arg1_dtype) {
return report_error(
format!(
@ -995,7 +1008,7 @@ impl<'a> Inferencer<'a> {
self.unifier.stringify(arg1_dtype),
).as_str(),
arg0.location,
)
);
}
let target_ty = if id == &"np_minimum".into() || id == &"np_maximum".into() {
@ -1004,14 +1017,13 @@ impl<'a> Inferencer<'a> {
self.primitives.float
};
let ret = if [
&arg0_ty,
&arg1_ty,
].into_iter().any(|arg_ty| arg_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) {
let ret = if [&arg0_ty, &arg1_ty].into_iter().any(|arg_ty| {
arg_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
}) {
// typeof_ndarray_broadcast requires both dtypes to be the same, but ldexp accepts
// (float, int32), so convert it to align with the dtype of the first arg
let arg1_ty = if id == &"np_ldexp".into() {
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
if arg1_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, arg1_ty);
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndims))
@ -1032,16 +1044,8 @@ impl<'a> Inferencer<'a> {
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "x1".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
FuncArg {
name: "x2".into(),
ty: arg1.custom.unwrap(),
default_value: None,
},
FuncArg { name: "x1".into(), ty: arg0.custom.unwrap(), default_value: None },
FuncArg { name: "x2".into(), ty: arg1.custom.unwrap(), default_value: None },
],
ret,
vars: VarMap::new(),
@ -1059,38 +1063,37 @@ impl<'a> Inferencer<'a> {
args: vec![arg0, arg1],
keywords: vec![],
},
}))
}));
}
// int64, uint32 and uint64 are special because their argument can be a constant outside the
// int64, uint32 and uint64 are special because their argument can be a constant outside the
// range of int32s
if [
"int64",
"uint32",
"uint64",
].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
if ["int64", "uint32", "uint64"].iter().any(|fun_id| id == &(*fun_id).into())
&& args.len() == 1
{
let target_ty = if id == &"int64".into() {
self.primitives.int64
} else if id == &"uint32".into() {
self.primitives.uint32
} else if id == &"uint64".into() {
self.primitives.uint64
} else { unreachable!() };
} else {
unreachable!()
};
// Handle constants first to ensure that their types are not defaulted to int32, which
// causes an "Integer out of bound" error
if let ExprKind::Constant {
value: ast::Constant::Int(val),
kind
} = &args[0].node {
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = &args[0].node {
let conv_is_ok = if self.unifier.unioned(target_ty, self.primitives.int64) {
i64::try_from(*val).is_ok()
} else if self.unifier.unioned(target_ty, self.primitives.uint32) {
u32::try_from(*val).is_ok()
} else if self.unifier.unioned(target_ty, self.primitives.uint64) {
u64::try_from(*val).is_ok()
} else { unreachable!() };
} else {
unreachable!()
};
return if conv_is_ok {
Ok(Some(Located {
location: args[0].location,
@ -1102,13 +1105,14 @@ impl<'a> Inferencer<'a> {
}))
} else {
report_error("Integer out of bound", args[0].location)
}
};
}
let arg0 = self.fold_expr(args.remove(0))?;
let arg0_ty = arg0.custom.unwrap();
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
{
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
make_ndarray_ty(self.unifier, self.primitives, Some(target_ty), Some(ndarray_ndims))
@ -1117,13 +1121,11 @@ impl<'a> Inferencer<'a> {
};
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "n".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
args: vec![FuncArg {
name: "n".into(),
ty: arg0.custom.unwrap(),
default_value: None,
}],
ret,
vars: VarMap::new(),
}));
@ -1140,30 +1142,29 @@ impl<'a> Inferencer<'a> {
args: vec![arg0],
keywords: vec![],
},
}))
}));
}
// 1-argument ndarray n-dimensional creation functions
if [
"np_ndarray".into(),
"np_empty".into(),
"np_zeros".into(),
"np_ones".into(),
].contains(id) && args.len() == 1 {
if ["np_ndarray".into(), "np_empty".into(), "np_zeros".into(), "np_ones".into()]
.contains(id)
&& args.len() == 1
{
let ExprKind::List { elts, .. } = &args[0].node else {
return report_error(
format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(),
args[0].location
)
format!(
"Expected List literal for first argument of {id}, got {}",
args[0].node.name()
)
.as_str(),
args[0].location,
);
};
let ndims = elts.len() as u64;
let arg0 = self.fold_expr(args.remove(0))?;
let ndims = self.unifier.get_fresh_literal(
vec![SymbolValue::U64(ndims)],
None,
);
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(
self.unifier,
self.primitives,
@ -1171,13 +1172,11 @@ impl<'a> Inferencer<'a> {
Some(ndims),
);
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "shape".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
],
args: vec![FuncArg {
name: "shape".into(),
ty: arg0.custom.unwrap(),
default_value: None,
}],
ret,
vars: VarMap::new(),
}));
@ -1194,16 +1193,20 @@ impl<'a> Inferencer<'a> {
args: vec![arg0],
keywords: vec![],
},
}))
}));
}
// 2-argument ndarray n-dimensional creation functions
if id == &"np_full".into() && args.len() == 2 {
let ExprKind::List { elts, .. } = &args[0].node else {
return report_error(
format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(),
args[0].location
)
format!(
"Expected List literal for first argument of {id}, got {}",
args[0].node.name()
)
.as_str(),
args[0].location,
);
};
let ndims = elts.len() as u64;
@ -1212,23 +1215,11 @@ impl<'a> Inferencer<'a> {
let arg1 = self.fold_expr(args.remove(0))?;
let ty = arg1.custom.unwrap();
let ndims = self.unifier.get_fresh_literal(
vec![SymbolValue::U64(ndims)],
None,
);
let ret = make_ndarray_ty(
self.unifier,
self.primitives,
Some(ty),
Some(ndims),
);
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "shape".into(),
ty: arg0.custom.unwrap(),
default_value: None,
},
FuncArg { name: "shape".into(), ty: arg0.custom.unwrap(), default_value: None },
FuncArg {
name: "fill_value".into(),
ty: arg1.custom.unwrap(),
@ -1251,18 +1242,19 @@ impl<'a> Inferencer<'a> {
args: vec![arg0, arg1],
keywords: vec![],
},
}))
}));
}
// 1-argument ndarray n-dimensional creation functions
if id == &"np_array".into() && args.len() == 1 {
let arg0 = self.fold_expr(args.remove(0))?;
let keywords = keywords.iter()
let keywords = keywords
.iter()
.map(|v| fold::fold_keyword(self, v.clone()))
.collect::<Result<Vec<_>, _>>()?;
let ndmin_kw = keywords.iter()
.find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into()));
let ndmin_kw =
keywords.iter().find(|kwarg| kwarg.node.arg.is_some_and(|id| id == "ndmin".into()));
let ty = arraylike_flatten_element_type(self.unifier, arg0.custom.unwrap());
let ndims = if let Some(ndmin_kw) = ndmin_kw {
@ -1270,30 +1262,22 @@ impl<'a> Inferencer<'a> {
ExprKind::Constant { value, .. } => match value {
ast::Constant::Int(value) => *value as u64,
_ => return Err(HashSet::from(["Expected uint64 for ndims".to_string()])),
}
},
_ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
_ => arraylike_get_ndims(self.unifier, arg0.custom.unwrap()),
}
} else {
arraylike_get_ndims(self.unifier, arg0.custom.unwrap())
};
let ndims = self.unifier.get_fresh_literal(
vec![SymbolValue::U64(ndims)],
None,
);
let ret = make_ndarray_ty(
self.unifier,
self.primitives,
Some(ty),
Some(ndims),
);
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
FuncArg {
name: "object".into(),
ty: arg0.custom.unwrap(),
default_value: None
default_value: None,
},
FuncArg {
name: "copy".into(),
@ -1322,7 +1306,7 @@ impl<'a> Inferencer<'a> {
args: vec![arg0],
keywords,
},
}))
}));
}
Ok(None)
@ -1335,8 +1319,10 @@ impl<'a> Inferencer<'a> {
mut args: Vec<ast::Expr<()>>,
keywords: Vec<Located<ast::KeywordData>>,
) -> Result<ast::Expr<Option<Type>>, HashSet<String>> {
if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? {
return Ok(spec_call_func)
if let Some(spec_call_func) =
self.try_fold_special_call(location, &func, &mut args, &keywords)?
{
return Ok(spec_call_func);
}
let func = Box::new(self.fold_expr(func)?);
@ -1365,11 +1351,9 @@ impl<'a> Inferencer<'a> {
.map(|v| v.name)
.rev()
.collect();
self.unifier
.unify_call(&call, func.custom.unwrap(), sign, &required)
.map_err(|e| HashSet::from([
e.at(Some(location)).to_display(self.unifier).to_string(),
]))?;
self.unifier.unify_call(&call, func.custom.unwrap(), sign, &required).map_err(
|e| HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()]),
)?;
return Ok(Located {
location,
custom: Some(sign.ret),
@ -1403,8 +1387,7 @@ impl<'a> Inferencer<'a> {
} else {
let variable_mapping = &mut self.variable_mapping;
let unifier: &mut Unifier = self.unifier;
self
.function_data
self.function_data
.resolver
.get_symbol_type(unifier, &self.top_level.definitions.read(), self.primitives, id)
.unwrap_or_else(|_| {
@ -1434,8 +1417,9 @@ impl<'a> Inferencer<'a> {
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? }))
}
ast::Constant::Str(_) => Ok(self.primitives.str),
ast::Constant::None
=> report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc),
ast::Constant::None => {
report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc)
}
ast::Constant::Ellipsis => Ok(self.unifier.get_fresh_var(None, None).0),
_ => report_error("not supported", *loc),
}
@ -1471,8 +1455,11 @@ impl<'a> Inferencer<'a> {
}
(None, _) => {
let t = self.unifier.stringify(ty);
report_error(&format!("`{t}::{attr}` field/method does not exist"), value.location)
},
report_error(
&format!("`{t}::{attr}` field/method does not exist"),
value.location,
)
}
}
} else {
let attr_ty = self.unifier.get_dummy_var().0;
@ -1509,10 +1496,8 @@ impl<'a> Inferencer<'a> {
let method = if let TypeEnum::TObj { fields, .. } =
self.unifier.get_ty_immutable(left_ty).as_ref()
{
let (binop_name, binop_assign_name) = (
binop_name(op).into(),
binop_assign_name(op).into()
);
let (binop_name, binop_assign_name) =
(binop_name(op).into(), binop_assign_name(op).into());
// if is aug_assign, try aug_assign operator first
if is_aug_assign && fields.contains_key(&binop_assign_name) {
binop_assign_name
@ -1527,22 +1512,11 @@ impl<'a> Inferencer<'a> {
// The type of augmented assignment operator should never change
Some(left_ty)
} else {
typeof_binop(
self.unifier,
self.primitives,
op,
left_ty,
right_ty,
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
typeof_binop(self.unifier, self.primitives, op, left_ty, right_ty)
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
};
self.build_method_call(
location,
method,
left_ty,
vec![right_ty],
ret,
)
self.build_method_call(location, method, left_ty, vec![right_ty], ret)
}
fn infer_unary_ops(
@ -1553,12 +1527,8 @@ impl<'a> Inferencer<'a> {
) -> InferenceResult {
let method = unaryop_name(op).into();
let ret = typeof_unaryop(
self.unifier,
self.primitives,
op,
operand.custom.unwrap(),
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
let ret = typeof_unaryop(self.unifier, self.primitives, op, operand.custom.unwrap())
.map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
self.build_method_call(location, method, operand.custom.unwrap(), vec![], ret)
}
@ -1570,16 +1540,23 @@ impl<'a> Inferencer<'a> {
ops: &[ast::Cmpop],
comparators: &[ast::Expr<Option<Type>>],
) -> InferenceResult {
if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) {
return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")]))
if ops.len() > 1
&& once(left).chain(comparators).any(|expr| {
expr.custom
.unwrap()
.obj_id(self.unifier)
.is_some_and(|id| id == PrimDef::NDArray.id())
})
{
return Err(HashSet::from([String::from(
"Comparator chaining with ndarray types not supported",
)]));
}
let mut res = None;
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method = comparison_name(c)
.ok_or_else(|| HashSet::from([
"unsupported comparator".to_string()
]))?
.ok_or_else(|| HashSet::from(["unsupported comparator".to_string()]))?
.into();
let ret = typeof_cmpop(
@ -1588,7 +1565,8 @@ impl<'a> Inferencer<'a> {
c,
a.custom.unwrap(),
b.custom.unwrap(),
).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?;
)
.map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?;
res.replace(self.build_method_call(
location,
@ -1614,28 +1592,29 @@ impl<'a> Inferencer<'a> {
TypeEnum::TVar { is_const_generic: false, .. }
));
let constrained_ty = make_ndarray_ty(
self.unifier,
self.primitives,
Some(dummy_tvar),
Some(ndims),
);
let constrained_ty =
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims));
self.constrain(value.custom.unwrap(), constrained_ty, &value.location)?;
let TypeEnum::TLiteral { values, .. } = &*self.unifier.get_ty_immutable(ndims) else {
panic!("Expected TLiteral for ndarray.ndims, got {}", self.unifier.stringify(ndims))
};
let ndims = values.iter()
let ndims = values
.iter()
.map(|ndim| match *ndim {
SymbolValue::U64(v) => Ok(v),
SymbolValue::U32(v) => Ok(v as u64),
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| HashSet::from([
format!("Expected non-negative literal for ndarray.ndims, got {v}"),
])),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| HashSet::from([
format!("Expected non-negative literal for ndarray.ndims, got {v}"),
])),
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| {
HashSet::from([format!(
"Expected non-negative literal for ndarray.ndims, got {v}"
)])
}),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| {
HashSet::from([format!(
"Expected non-negative literal for ndarray.ndims, got {v}"
)])
}),
_ => unreachable!(),
})
.collect::<Result<Vec<_>, _>>()?;
@ -1684,32 +1663,35 @@ impl<'a> Inferencer<'a> {
}
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
}
_ => unreachable!()
_ => unreachable!(),
};
self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?;
Ok(list_like_ty)
}
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
self.infer_subscript_ndarray(value, ty, ndims)
}
_ => {
// the index is a constant, so value can be a sequence.
let ind: Option<i32> = (*val).try_into().ok();
let ind = ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
let ind =
ind.ok_or_else(|| HashSet::from(["Index must be int32".to_string()]))?;
let map = once((
ind.into(),
RecordField::new(ty, ctx == &ExprContext::Store, Some(value.location)),
))
.collect();
.collect();
let seq = self.unifier.add_record(map);
self.constrain(value.custom.unwrap(), seq, &value.location)?;
Ok(ty)
@ -1717,54 +1699,67 @@ impl<'a> Inferencer<'a> {
}
}
ExprKind::Tuple { elts, .. } => {
if value.custom
if value
.custom
.unwrap()
.obj_id(self.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)
.not() {
return report_error("Tuple slices are only supported for ndarrays", slice.location)
.is_some_and(|id| id == PrimDef::NDArray.id())
.not()
{
return report_error(
"Tuple slices are only supported for ndarrays",
slice.location,
);
}
for elt in elts {
if let ExprKind::Slice { lower, upper, step } = &elt.node {
for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() {
self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?;
}
}
} else {
self.constrain(elt.custom.unwrap(), self.primitives.int32, &elt.location)?;
}
}
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndarray_ty = make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
let ndarray_ty =
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
Ok(ndarray_ty)
}
_ => {
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location)
return report_error(
"Tuple index must be a constant (KernelInvariant is also not supported)",
slice.location,
);
}
// the index is not a constant, so value can only be a list-like structure
match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => {
self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?;
self.constrain(
slice.custom.unwrap(),
self.primitives.int32,
&slice.location,
)?;
let list = self.unifier.add_ty(TypeEnum::TList { ty });
self.constrain(value.custom.unwrap(), list, &value.location)?;
Ok(ty)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let valid_index_tys = [
self.primitives.int32,
self.primitives.isize(),
].into_iter().unique().collect_vec();
let valid_index_ty = self.unifier.get_fresh_var_with_range(
valid_index_tys.as_slice(),
None,
None,
).0;
let valid_index_tys = [self.primitives.int32, self.primitives.isize()]
.into_iter()
.unique()
.collect_vec();
let valid_index_ty = self
.unifier
.get_fresh_var_with_range(valid_index_tys.as_slice(), None, None)
.0;
self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?;
self.infer_subscript_ndarray(value, ty, ndims)
}

View File

@ -3,12 +3,12 @@ use super::*;
use crate::{
codegen::CodeGenContext,
symbol_resolver::ValueEnum,
toplevel::{DefinitionId, helper::PRIMITIVE_DEF_IDS, TopLevelDef},
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
};
use indoc::indoc;
use std::iter::zip;
use nac3parser::parser::parse_program;
use parking_lot::RwLock;
use std::iter::zip;
use test_case::test_case;
struct Resolver {
@ -44,7 +44,9 @@ impl SymbolResolver for Resolver {
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.id_to_def.get(&id).cloned()
self.id_to_def
.get(&id)
.cloned()
.ok_or_else(|| HashSet::from(["Unknown identifier".to_string()]))
}
@ -73,7 +75,7 @@ impl TestEnvironment {
let mut unifier = Unifier::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int32,
obj_id: PrimDef::Int32.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
@ -86,59 +88,60 @@ impl TestEnvironment {
fields.insert("__add__".into(), (add_ty, false));
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int64,
obj_id: PrimDef::Int64.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.float,
obj_id: PrimDef::Float.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.bool,
obj_id: PrimDef::Bool.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.none,
obj_id: PrimDef::None.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.range,
obj_id: PrimDef::Range.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.str,
obj_id: PrimDef::Str.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.exception,
obj_id: PrimDef::Exception.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint32,
obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint64,
obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.option,
obj_id: PrimDef::Option.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray_ndims_tvar =
unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
@ -211,7 +214,7 @@ impl TestEnvironment {
let mut identifier_mapping = HashMap::new();
let mut top_level_defs: Vec<Arc<RwLock<TopLevelDef>>> = Vec::new();
let int32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int32,
obj_id: PrimDef::Int32.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
@ -224,57 +227,57 @@ impl TestEnvironment {
fields.insert("__add__".into(), (add_ty, false));
});
let int64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.int64,
obj_id: PrimDef::Int64.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let float = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.float,
obj_id: PrimDef::Float.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let bool = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.bool,
obj_id: PrimDef::Bool.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let none = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.none,
obj_id: PrimDef::None.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let range = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.range,
obj_id: PrimDef::Range.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let str = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.str,
obj_id: PrimDef::Str.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let exception = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.exception,
obj_id: PrimDef::Exception.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let uint32 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint32,
obj_id: PrimDef::UInt32.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let uint64 = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.uint64,
obj_id: PrimDef::UInt64.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let option = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.option,
obj_id: PrimDef::Option.id(),
fields: HashMap::new(),
params: VarMap::new(),
});
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
obj_id: PrimDef::NDArray.id(),
fields: HashMap::new(),
params: VarMap::new(),
});

View File

@ -1,12 +1,12 @@
use indexmap::IndexMap;
use itertools::Itertools;
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Display;
use std::iter::zip;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::{borrow::Cow, collections::HashSet};
use std::iter::zip;
use indexmap::IndexMap;
use itertools::Itertools;
use nac3parser::ast::{Location, StrRef};
@ -61,7 +61,7 @@ pub enum RecordKey {
}
impl Type {
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
/// just to get the field `obj_id`.
#[must_use]
pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> {
@ -250,9 +250,9 @@ impl Unifier {
}
/// Returns the [`UnificationTable`] associated with this `Unifier`.
///
///
/// # Safety
///
///
/// The use of this function is discouraged under most circumstances. Only use this function if
/// in-place manipulation of type variables and/or type fields is necessary, otherwise prefer to
/// [add a new type][`Unifier::add_ty`] and [unify the type][`Unifier::unify`] with an existing
@ -379,7 +379,17 @@ impl Unifier {
let id = self.var_id + 1;
self.var_id += 1;
let range = range.to_vec();
(self.add_ty(TypeEnum::TVar { id, range, fields: None, name, loc, is_const_generic: false }), id)
(
self.add_ty(TypeEnum::TVar {
id,
range,
fields: None,
name,
loc,
is_const_generic: false,
}),
id,
)
}
/// Returns a fresh type representing a constant generic variable with the given underlying type `ty`.
@ -391,19 +401,22 @@ impl Unifier {
) -> (Type, u32) {
let id = self.var_id + 1;
self.var_id += 1;
(self.add_ty(TypeEnum::TVar { id, range: vec![ty], fields: None, name, loc, is_const_generic: true }), id)
(
self.add_ty(TypeEnum::TVar {
id,
range: vec![ty],
fields: None,
name,
loc,
is_const_generic: true,
}),
id,
)
}
/// Returns a fresh type representing a [literal][TypeEnum::TConstant] with the given `values`.
pub fn get_fresh_literal(
&mut self,
values: Vec<SymbolValue>,
loc: Option<Location>,
) -> Type {
let ty_enum = TypeEnum::TLiteral {
values: values.into_iter().dedup().collect(),
loc
};
pub fn get_fresh_literal(&mut self, values: Vec<SymbolValue>, loc: Option<Location>) -> Type {
let ty_enum = TypeEnum::TLiteral { values: values.into_iter().dedup().collect(), loc };
self.add_ty(ty_enum)
}
@ -423,7 +436,9 @@ impl Unifier {
Some(
range
.iter()
.flat_map(|ty| self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty]))
.flat_map(|ty| {
self.get_instantiations(*ty).unwrap_or_else(|| vec![*ty])
})
.collect_vec(),
)
}
@ -479,7 +494,7 @@ impl Unifier {
pub fn is_concrete(&mut self, a: Type, allowed_typevars: &[Type]) -> bool {
use TypeEnum::*;
match &*self.get_ty(a) {
TRigidVar { .. }
TRigidVar { .. }
| TLiteral { .. }
// functions are instantiated for each call sites, so the function type can contain
// type variables.
@ -487,7 +502,7 @@ impl Unifier {
TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)),
TCall { .. } => false,
TList { ty }
TList { ty }
| TVirtual { ty } => self.is_concrete(*ty, allowed_typevars),
TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)),
@ -526,9 +541,7 @@ impl Unifier {
let instantiated = self.instantiate_fun(b, signature);
let r = self.get_ty(instantiated);
let r = r.as_ref();
let TypeEnum::TFunc(signature) = r else {
unreachable!()
};
let TypeEnum::TFunc(signature) = r else { unreachable!() };
// we check to make sure that all required arguments (those without default
// arguments) are provided, and do not provide the same argument twice.
let mut required = required.to_vec();
@ -555,13 +568,10 @@ impl Unifier {
if let Some(i) = required.iter().position(|v| v == k) {
required.remove(i);
}
let i = all_names
.iter()
.position(|v| &v.0 == k)
.ok_or_else(|| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
})?;
let i = all_names.iter().position(|v| &v.0 == k).ok_or_else(|| {
self.restore_snapshot();
TypeError::new(TypeErrorKind::UnknownArgName(*k), *loc)
})?;
let (name, expected) = all_names.remove(i);
self.unify_impl(expected, *t, false).map_err(|_| {
self.restore_snapshot();
@ -627,8 +637,17 @@ impl Unifier {
};
match (&*ty_a, &*ty_b) {
(
TVar { fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, .. },
TVar { fields: fields2, id: id2, name: name2, loc: loc2, is_const_generic: false, .. },
TVar {
fields: fields1, id, name: name1, loc: loc1, is_const_generic: false, ..
},
TVar {
fields: fields2,
id: id2,
name: name2,
loc: loc2,
is_const_generic: false,
..
},
) => {
let new_fields = match (fields1, fields2) {
(None, None) => None,
@ -750,7 +769,10 @@ impl Unifier {
self.set_a_to_b(a, x);
}
(TVar { id: id1, range: ty1, is_const_generic: true, .. }, TVar { id: id2, range: ty2, .. }) => {
(
TVar { id: id1, range: ty1, is_const_generic: true, .. },
TVar { id: id2, range: ty2, .. },
) => {
let ty1 = ty1[0];
let ty2 = ty2[0];
@ -765,17 +787,17 @@ impl Unifier {
assert_eq!(tys.len(), 1);
assert_eq!(values.len(), 1);
let primitives = &self.primitive_store
.expect("Expected PrimitiveStore to be present");
let primitives =
&self.primitive_store.expect("Expected PrimitiveStore to be present");
let ty = tys[0];
let value= &values[0];
let value = &values[0];
let value_ty = value.get_type(primitives, self);
// If the types don't match, try to implicitly promote integers
if !self.unioned(ty, value_ty) {
let Ok(num_val) = i128::try_from(value.clone()) else {
return Self::incompatible_types(a, b)
return Self::incompatible_types(a, b);
};
let can_convert = if self.unioned(ty, primitives.int32) {
@ -791,7 +813,7 @@ impl Unifier {
};
if !can_convert {
return Self::incompatible_types(a, b)
return Self::incompatible_types(a, b);
}
}
@ -816,7 +838,7 @@ impl Unifier {
let v2i = symbol_value_to_int(v2);
if v1i != v2i {
return Self::incompatible_types(a, b)
return Self::incompatible_types(a, b);
}
}
}
@ -1287,8 +1309,8 @@ impl Unifier {
mapping: &VarMap,
cache: &mut HashMap<Type, Option<Type>>,
) -> Option<IndexMapping<K>>
where
K: std::hash::Hash + Eq + Clone,
where
K: std::hash::Hash + Eq + Clone,
{
let mut map2 = None;
for (k, v) in map {

View File

@ -45,9 +45,9 @@ impl Unifier {
}
}
fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
where
K: std::hash::Hash + Eq + Clone
fn map_eq<K>(&mut self, map1: &IndexMapping<K>, map2: &IndexMapping<K>) -> bool
where
K: std::hash::Hash + Eq + Clone,
{
if map1.len() != map2.len() {
return false;
@ -342,16 +342,12 @@ fn test_recursive_subst() {
with_fields(&mut env.unifier, foo_id, |_unifier, fields| {
fields.insert("rec".into(), (foo_id, true));
});
let TypeEnum::TObj { params, .. } = &*foo_ty else {
unreachable!()
};
let TypeEnum::TObj { params, .. } = &*foo_ty else { unreachable!() };
let mapping = params.iter().map(|(id, _)| (*id, int)).collect();
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
let instantiated_ty = env.unifier.get_ty(instantiated);
let TypeEnum::TObj { fields, .. } = &*instantiated_ty else {
unreachable!()
};
let TypeEnum::TObj { fields, .. } = &*instantiated_ty else { unreachable!() };
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
}
@ -477,7 +473,8 @@ fn test_typevar_range() {
assert_eq!(
env.unify(a_list, int_list),
Err("Incompatible types: list[typevar22] and list[0]\
\n\nNotes:\n typevar22 {1}".into())
\n\nNotes:\n typevar22 {1}"
.into())
);
let a = env.unifier.get_fresh_var_with_range(&[int, float], None, None).0;
@ -505,7 +502,10 @@ fn test_rigid_var() {
assert_eq!(env.unify(a, b), Err("Incompatible types: typevar3 and typevar2".to_string()));
env.unifier.unify(list_a, list_x).unwrap();
assert_eq!(env.unify(list_x, list_int), Err("Incompatible types: list[typevar2] and list[0]".to_string()));
assert_eq!(
env.unify(list_x, list_int),
Err("Incompatible types: list[typevar2] and list[0]".to_string())
);
env.unifier.replace_rigid_var(a, int);
env.unifier.unify(list_x, list_int).unwrap();

View File

@ -16,21 +16,10 @@ pub struct UnificationTable<V> {
#[derive(Clone, Debug)]
enum Action<V> {
Parent {
key: usize,
original_parent: usize,
},
Value {
key: usize,
original_value: Option<V>,
},
Rank {
key: usize,
original_rank: u32,
},
Marker {
generation: u32,
}
Parent { key: usize, original_parent: usize },
Value { key: usize, original_value: Option<V> },
Rank { key: usize, original_rank: u32 },
Marker { generation: u32 },
}
impl<V> Default for UnificationTable<V> {
@ -41,7 +30,13 @@ impl<V> Default for UnificationTable<V> {
impl<V> UnificationTable<V> {
pub fn new() -> UnificationTable<V> {
UnificationTable { parents: Vec::new(), ranks: Vec::new(), values: Vec::new(), log: Vec::new(), generation: 0 }
UnificationTable {
parents: Vec::new(),
ranks: Vec::new(),
values: Vec::new(),
log: Vec::new(),
generation: 0,
}
}
pub fn new_key(&mut self, v: V) -> UnificationKey {
@ -125,7 +120,10 @@ impl<V> UnificationTable<V> {
pub fn restore_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot restoration error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot restoration error");
assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot restoration error"
);
for action in self.log.drain(log_len - 1..).rev() {
match action {
Action::Parent { key, original_parent } => {
@ -145,7 +143,10 @@ impl<V> UnificationTable<V> {
pub fn discard_snapshot(&mut self, snapshot: (usize, u32)) {
let (log_len, generation) = snapshot;
assert!(self.log.len() >= log_len, "snapshot discard error");
assert!(matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation), "snapshot discard error");
assert!(
matches!(self.log[log_len - 1], Action::Marker { generation: gen } if gen == generation),
"snapshot discard error"
);
self.log.clear();
}
}
@ -159,11 +160,23 @@ where
.enumerate()
.map(|(i, (v, p))| if *p == i { v.as_ref().map(|v| v.as_ref().clone()) } else { None })
.collect();
UnificationTable { parents: self.parents.clone(), ranks: self.ranks.clone(), values, log: Vec::new(), generation: 0 }
UnificationTable {
parents: self.parents.clone(),
ranks: self.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
}
pub fn from_send(table: &UnificationTable<V>) -> UnificationTable<Rc<V>> {
let values = table.values.iter().cloned().map(|v| v.map(Rc::new)).collect();
UnificationTable { parents: table.parents.clone(), ranks: table.ranks.clone(), values, log: Vec::new(), generation: 0 }
UnificationTable {
parents: table.parents.clone(),
ranks: table.ranks.clone(),
values,
log: Vec::new(),
generation: 0,
}
}
}

View File

@ -32,7 +32,6 @@ pub struct DwarfReader<'a> {
}
impl<'a> DwarfReader<'a> {
pub fn new(slice: &[u8], virt_addr: u32) -> DwarfReader {
DwarfReader { slice, virt_addr, base_slice: slice, base_virt_addr: virt_addr }
}
@ -170,10 +169,7 @@ fn read_encoded_pointer(reader: &mut DwarfReader, encoding: u8) -> Result<usize,
}
}
fn read_encoded_pointer_with_pc(
reader: &mut DwarfReader,
encoding: u8,
) -> Result<usize, ()> {
fn read_encoded_pointer_with_pc(reader: &mut DwarfReader, encoding: u8) -> Result<usize, ()> {
let entry_virt_addr = reader.virt_addr;
let mut result = read_encoded_pointer(reader, encoding)?;
@ -223,7 +219,6 @@ pub struct EH_Frame<'a> {
}
impl<'a> EH_Frame<'a> {
/// Creates an [EH_Frame] using the bytes in the `.eh_frame` section and its address in the ELF
/// file.
pub fn new(eh_frame_slice: &[u8], eh_frame_addr: u32) -> Result<EH_Frame, ()> {
@ -235,10 +230,7 @@ impl<'a> EH_Frame<'a> {
let reader = DwarfReader::from_reader(&self.reader, true);
let len = reader.slice.len();
CFI_Records {
reader,
available: len,
}
CFI_Records { reader, available: len }
}
}
@ -255,7 +247,6 @@ pub struct CFI_Record<'a> {
}
impl<'a> CFI_Record<'a> {
pub fn from_reader(cie_reader: &mut DwarfReader<'a>) -> Result<CFI_Record<'a>, ()> {
let length = cie_reader.read_u32();
let fde_reader = match length {
@ -323,10 +314,7 @@ impl<'a> CFI_Record<'a> {
}
assert_ne!(fde_pointer_encoding, DW_EH_PE_omit);
Ok(CFI_Record {
fde_pointer_encoding,
fde_reader,
})
Ok(CFI_Record { fde_pointer_encoding, fde_reader })
}
/// Returns a [DwarfReader] initialized to the first Frame Description Entry (FDE) of this CFI
@ -340,11 +328,7 @@ impl<'a> CFI_Record<'a> {
let reader = self.get_fde_reader();
let len = reader.slice.len();
FDE_Records {
pointer_encoding: self.fde_pointer_encoding,
reader,
available: len,
}
FDE_Records { pointer_encoding: self.fde_pointer_encoding, reader, available: len }
}
}
@ -387,7 +371,7 @@ impl<'a> Iterator for CFI_Records<'a> {
// Skip this record if it is a FDE
if cie_ptr == 0 {
// Rewind back to the start of the CFI Record
return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap())
return Some(CFI_Record::from_reader(&mut this_reader).ok().unwrap());
}
}
}
@ -448,7 +432,6 @@ pub struct EH_Frame_Hdr<'a> {
}
impl<'a> EH_Frame_Hdr<'a> {
/// Create a [EH_Frame_Hdr] object, and write out the fixed fields of `.eh_frame_hdr` to memory.
///
/// Load address is not known at this point.
@ -459,15 +442,16 @@ impl<'a> EH_Frame_Hdr<'a> {
) -> EH_Frame_Hdr {
let mut writer = DwarfWriter::new(eh_frame_hdr_slice);
writer.write_u8(1); // version
writer.write_u8(0x1B); // eh_frame_ptr_enc - PC-relative 4-byte signed value
writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value
writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value
writer.write_u8(1); // version
writer.write_u8(0x1B); // eh_frame_ptr_enc - PC-relative 4-byte signed value
writer.write_u8(0x03); // fde_count_enc - 4-byte unsigned value
writer.write_u8(0x3B); // table_enc - .eh_frame_hdr section-relative 4-byte signed value
let eh_frame_offset = eh_frame_addr
.wrapping_sub(eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4));
writer.write_u32(eh_frame_offset); // eh_frame_ptr
writer.write_u32(0); // `fde_count`, will be written in finalize_fde
let eh_frame_offset = eh_frame_addr.wrapping_sub(
eh_frame_hdr_addr + writer.offset as u32 + ((mem::size_of::<u8>() as u32) * 4),
);
writer.write_u32(eh_frame_offset); // eh_frame_ptr
writer.write_u32(0); // `fde_count`, will be written in finalize_fde
EH_Frame_Hdr { fde_writer: writer, eh_frame_hdr_addr, fdes: Vec::new() }
}
@ -492,7 +476,10 @@ impl<'a> EH_Frame_Hdr<'a> {
self.fde_writer.write_u32(*init_loc);
self.fde_writer.write_u32(*addr);
}
LittleEndian::write_u32(&mut self.fde_writer.slice[Self::fde_count_offset()..], self.fdes.len() as u32);
LittleEndian::write_u32(
&mut self.fde_writer.slice[Self::fde_count_offset()..],
self.fdes.len() as u32,
);
}
pub fn size_from_eh_frame(eh_frame: &[u8]) -> usize {

View File

@ -205,11 +205,9 @@ impl<'a> Linker<'a> {
for reloc in relocs {
let sym = match reloc.sym_info() as usize {
STN_UNDEF => None,
sym_index => Some(
self.symtab
.get(sym_index)
.ok_or("symbol out of bounds of symbol table")?,
),
sym_index => {
Some(self.symtab.get(sym_index).ok_or("symbol out of bounds of symbol table")?)
}
};
let resolve_symbol_addr =
@ -314,9 +312,8 @@ impl<'a> Linker<'a> {
R_RISCV_PCREL_LO12_I => {
let expected_offset = sym_option.map_or(0, |sym| sym.st_value);
let indirect_reloc = relocs
.iter()
.find(|reloc| reloc.offset() == expected_offset)?;
let indirect_reloc =
relocs.iter().find(|reloc| reloc.offset() == expected_offset)?;
Some(RelocInfo {
defined_val: {
let indirect_sym =
@ -354,10 +351,7 @@ impl<'a> Linker<'a> {
indirect_reloc: None,
pc_relative: false,
relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u32(
target_word,
value,
)
LittleEndian::write_u32(target_word, value)
})),
}),
@ -386,10 +380,7 @@ impl<'a> Linker<'a> {
indirect_reloc: None,
pc_relative: false,
relocate: Some(Box::new(|target_word, value| {
LittleEndian::write_u16(
target_word,
value as u16,
)
LittleEndian::write_u16(target_word, value as u16)
})),
}),
@ -552,9 +543,12 @@ impl<'a> Linker<'a> {
eh_frame_hdr_rec.shdr.sh_offset,
eh_frame_rec.shdr.sh_offset,
);
eh_frame.cfi_records()
.flat_map(|cfi| cfi.fde_records())
.for_each(&mut |(init_pos, virt_addr)| eh_frame_hdr.add_fde(init_pos, virt_addr));
eh_frame.cfi_records().flat_map(|cfi| cfi.fde_records()).for_each(&mut |(
init_pos,
virt_addr,
)| {
eh_frame_hdr.add_fde(init_pos, virt_addr)
});
// Sort FDE entries in .eh_frame_hdr
eh_frame_hdr.finalize_fde();
@ -599,24 +593,22 @@ impl<'a> Linker<'a> {
// Section table for the .elf paired with the section name
// To be formalized incrementally
// Very hashmap-like structure, but the order matters, so it is a vector
let elf_shdrs = vec![
SectionRecord {
shdr: Elf32_Shdr {
sh_name: 0,
sh_type: 0,
sh_flags: 0,
sh_addr: 0,
sh_offset: 0,
sh_size: 0,
sh_link: 0,
sh_info: 0,
sh_addralign: 0,
sh_entsize: 0,
},
name: "",
data: vec![0; 0],
let elf_shdrs = vec![SectionRecord {
shdr: Elf32_Shdr {
sh_name: 0,
sh_type: 0,
sh_flags: 0,
sh_addr: 0,
sh_offset: 0,
sh_size: 0,
sh_link: 0,
sh_info: 0,
sh_addralign: 0,
sh_entsize: 0,
},
];
name: "",
data: vec![0; 0],
}];
let elf_sh_data_off = mem::size_of::<Elf32_Ehdr>() + mem::size_of::<Elf32_Phdr>() * 5;
// Image of the linked dynamic library, to be formalized incrementally
@ -1010,7 +1002,9 @@ impl<'a> Linker<'a> {
let mut hash_bucket: Vec<u32> = vec![0; dynsym.len()];
let mut hash_chain: Vec<u32> = vec![0; dynsym.len()];
for (sym_index, (str_start, str_end)) in dynsym_names.iter().enumerate().take(dynsym.len()).skip(1) {
for (sym_index, (str_start, str_end)) in
dynsym_names.iter().enumerate().take(dynsym.len()).skip(1)
{
let hash = elf_hash(&dynstr[*str_start..*str_end]);
let mut hash_index = hash as usize % hash_bucket.len();
@ -1253,7 +1247,9 @@ impl<'a> Linker<'a> {
update_dynsym_record!(b"__bss_start", bss_offset, bss_elf_index as Elf32_Section);
update_dynsym_record!(b"_end", bss_offset, bss_elf_index as Elf32_Section);
} else {
for (bss_iter_index, &(bss_section_index, section_name)) in bss_index_vec.iter().enumerate() {
for (bss_iter_index, &(bss_section_index, section_name)) in
bss_index_vec.iter().enumerate()
{
let shdr = &shdrs[bss_section_index];
let bss_elf_index = linker.load_section(
shdr,

View File

@ -1,15 +1,15 @@
use lalrpop_util::ParseError;
use nac3ast::*;
use crate::ast::Ident;
use crate::ast::Location;
use crate::token::Tok;
use crate::error::*;
use crate::token::Tok;
use lalrpop_util::ParseError;
use nac3ast::*;
pub fn make_config_comment(
com_loc: Location,
stmt_loc: Location,
nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident>
nac3com_end: Option<Ident>,
) -> Result<Vec<Ident>, ParseError<Location, Tok, LexicalError>> {
if com_loc.column() != stmt_loc.column() && !nac3com_above.is_empty() {
return Err(ParseError::User {
@ -23,18 +23,21 @@ pub fn make_config_comment(
)
)
}
})
});
};
Ok(
nac3com_above
.into_iter()
.map(|(com, _)| com)
.chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter()))
.collect()
)
Ok(nac3com_above
.into_iter()
.map(|(com, _)| com)
.chain(nac3com_end.map_or_else(|| vec![].into_iter(), |com| vec![com].into_iter()))
.collect())
}
pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, Tok)>, nac3com_end: Option<Ident>, com_above_loc: Location) -> Result<(), ParseError<Location, Tok, LexicalError>> {
pub fn handle_small_stmt<U>(
stmts: &mut [Stmt<U>],
nac3com_above: Vec<(Ident, Tok)>,
nac3com_end: Option<Ident>,
com_above_loc: Location,
) -> Result<(), ParseError<Location, Tok, LexicalError>> {
if com_above_loc.column() != stmts[0].location.column() && !nac3com_above.is_empty() {
return Err(ParseError::User {
error: LexicalError {
@ -47,17 +50,12 @@ pub fn handle_small_stmt<U>(stmts: &mut [Stmt<U>], nac3com_above: Vec<(Ident, To
)
)
}
})
});
}
apply_config_comments(
&mut stmts[0],
nac3com_above
.into_iter()
.map(|(com, _)| com).collect()
);
apply_config_comments(&mut stmts[0], nac3com_above.into_iter().map(|(com, _)| com).collect());
apply_config_comments(
stmts.last_mut().unwrap(),
nac3com_end.map_or_else(Vec::new, |com| vec![com])
nac3com_end.map_or_else(Vec::new, |com| vec![com]),
);
Ok(())
}
@ -72,7 +70,7 @@ fn apply_config_comments<U>(stmt: &mut Stmt<U>, comments: Vec<Ident>) {
| StmtKind::AnnAssign { config_comment, .. }
| StmtKind::Break { config_comment, .. }
| StmtKind::Continue { config_comment, .. }
| StmtKind::Return { config_comment, .. }
| StmtKind::Return { config_comment, .. }
| StmtKind::Raise { config_comment, .. }
| StmtKind::Import { config_comment, .. }
| StmtKind::ImportFrom { config_comment, .. }
@ -80,6 +78,8 @@ fn apply_config_comments<U>(stmt: &mut Stmt<U>, comments: Vec<Ident>) {
| StmtKind::Nonlocal { config_comment, .. }
| StmtKind::Assert { config_comment, .. } => config_comment.extend(comments),
_ => { unreachable!("only small statements should call this function") }
_ => {
unreachable!("only small statements should call this function")
}
}
}

View File

@ -145,35 +145,27 @@ impl From<LalrpopError<Location, Tok, LexicalError>> for ParseError {
fn from(err: LalrpopError<Location, Tok, LexicalError>) -> Self {
match err {
// TODO: Are there cases where this isn't an EOF?
LalrpopError::InvalidToken { location } => ParseError {
error: ParseErrorType::Eof,
location,
},
LalrpopError::ExtraToken { token } => ParseError {
error: ParseErrorType::ExtraToken(token.1),
location: token.0,
},
LalrpopError::User { error } => ParseError {
error: ParseErrorType::Lexical(error.error),
location: error.location,
},
LalrpopError::InvalidToken { location } => {
ParseError { error: ParseErrorType::Eof, location }
}
LalrpopError::ExtraToken { token } => {
ParseError { error: ParseErrorType::ExtraToken(token.1), location: token.0 }
}
LalrpopError::User { error } => {
ParseError { error: ParseErrorType::Lexical(error.error), location: error.location }
}
LalrpopError::UnrecognizedToken { token, expected } => {
// Hacky, but it's how CPython does it. See PyParser_AddToken,
// in particular "Only one possible expected token" comment.
let expected = if expected.len() == 1 {
Some(expected[0].clone())
} else {
None
};
let expected = if expected.len() == 1 { Some(expected[0].clone()) } else { None };
ParseError {
error: ParseErrorType::UnrecognizedToken(token.1, expected),
location: token.0,
}
}
LalrpopError::UnrecognizedEof { location, .. } => ParseError {
error: ParseErrorType::Eof,
location,
},
LalrpopError::UnrecognizedEof { location, .. } => {
ParseError { error: ParseErrorType::Eof, location }
}
}
}
}

View File

@ -15,10 +15,7 @@ struct FStringParser<'a> {
impl<'a> FStringParser<'a> {
fn new(source: &'a str, str_location: Location) -> Self {
Self {
chars: source.chars().peekable(),
str_location,
}
Self { chars: source.chars().peekable(), str_location }
}
#[inline]
@ -251,17 +248,11 @@ impl<'a> FStringParser<'a> {
}
if !content.is_empty() {
values.push(self.expr(ExprKind::Constant {
value: content.into(),
kind: None,
}))
values.push(self.expr(ExprKind::Constant { value: content.into(), kind: None }))
}
let s = match values.len() {
0 => self.expr(ExprKind::Constant {
value: String::new().into(),
kind: None,
}),
0 => self.expr(ExprKind::Constant { value: String::new().into(), kind: None }),
1 => values.into_iter().next().unwrap(),
_ => self.expr(ExprKind::JoinedStr { values }),
};
@ -277,9 +268,7 @@ fn parse_fstring_expr(source: &str) -> Result<Expr, ParseError> {
/// Parse an fstring from a string, located at a certain position in the sourcecode.
/// In case of errors, we will get the location and the error returned.
pub fn parse_located_fstring(source: &str, location: Location) -> Result<Expr, FStringError> {
FStringParser::new(source, location)
.parse()
.map_err(|error| FStringError { error, location })
FStringParser::new(source, location).parse().map_err(|error| FStringError { error, location })
}
#[cfg(test)]

View File

@ -69,10 +69,7 @@ pub fn parse_args(func_args: Vec<FunctionArgument>) -> Result<ArgumentList, Lexi
keywords.push(ast::Keyword::new(
location,
ast::KeywordData {
arg: name.map(|name| name.into()),
value: Box::new(value),
},
ast::KeywordData { arg: name.map(|name| name.into()), value: Box::new(value) },
));
}
None => {

View File

@ -3,12 +3,12 @@
//! This means source code is translated into separate tokens.
pub use super::token::Tok;
use crate::ast::{Location, FileName};
use crate::ast::{FileName, Location};
use crate::error::{LexicalError, LexicalErrorType};
use std::char;
use std::cmp::Ordering;
use std::str::FromStr;
use std::num::IntErrorKind;
use std::str::FromStr;
use unic_emoji_char::is_emoji_presentation;
use unic_ucd_ident::{is_xid_continue, is_xid_start};
@ -32,20 +32,14 @@ impl IndentationLevel {
if self.spaces <= other.spaces {
Ok(Ordering::Less)
} else {
Err(LexicalError {
location,
error: LexicalErrorType::TabError,
})
Err(LexicalError { location, error: LexicalErrorType::TabError })
}
}
Ordering::Greater => {
if self.spaces >= other.spaces {
Ok(Ordering::Greater)
} else {
Err(LexicalError {
location,
error: LexicalErrorType::TabError,
})
Err(LexicalError { location, error: LexicalErrorType::TabError })
}
}
Ordering::Equal => Ok(self.spaces.cmp(&other.spaces)),
@ -63,7 +57,7 @@ pub struct Lexer<T: Iterator<Item = char>> {
chr1: Option<char>,
chr2: Option<char>,
location: Location,
config_comment_prefix: Option<&'static str>
config_comment_prefix: Option<&'static str>,
}
pub static KEYWORDS: phf::Map<&'static str, Tok> = phf::phf_map! {
@ -136,11 +130,7 @@ where
T: Iterator<Item = char>,
{
pub fn new(source: T) -> Self {
let mut nlh = NewlineHandler {
source,
chr0: None,
chr1: None,
};
let mut nlh = NewlineHandler { source, chr0: None, chr1: None };
nlh.shift();
nlh.shift();
nlh
@ -195,7 +185,7 @@ where
location: start,
chr1: None,
chr2: None,
config_comment_prefix: Some(" nac3:")
config_comment_prefix: Some(" nac3:"),
};
lxr.next_char();
lxr.next_char();
@ -287,15 +277,15 @@ where
let end_pos = self.get_pos();
let value = match i128::from_str_radix(&value_text, radix) {
Ok(value) => value,
Err(e) => {
match e.kind() {
IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX,
_ => return Err(LexicalError {
Err(e) => match e.kind() {
IntErrorKind::PosOverflow | IntErrorKind::NegOverflow => i128::MAX,
_ => {
return Err(LexicalError {
error: LexicalErrorType::OtherError(format!("{:?}", e)),
location: start_pos,
}),
})
}
}
},
};
Ok((start_pos, Tok::Int { value }, end_pos))
}
@ -338,14 +328,7 @@ where
if self.chr0 == Some('j') || self.chr0 == Some('J') {
self.next_char();
let end_pos = self.get_pos();
Ok((
start_pos,
Tok::Complex {
real: 0.0,
imag: value,
},
end_pos,
))
Ok((start_pos, Tok::Complex { real: 0.0, imag: value }, end_pos))
} else {
let end_pos = self.get_pos();
Ok((start_pos, Tok::Float { value }, end_pos))
@ -364,7 +347,7 @@ where
let value = value_text.parse::<i128>().ok();
let nonzero = match value {
Some(value) => value != 0i128,
None => true
None => true,
};
if start_is_zero && nonzero {
return Err(LexicalError {
@ -433,9 +416,8 @@ where
fn lex_comment(&mut self) -> Option<Spanned> {
self.next_char();
// if possibly nac3 pseudocomment, special handling for `# nac3:`
let (mut prefix, mut is_comment) = self
.config_comment_prefix
.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
let (mut prefix, mut is_comment) =
self.config_comment_prefix.map_or_else(|| ("".chars(), false), |v| (v.chars(), true));
// for the correct location of config comment
let mut start_loc = self.location;
start_loc.go_left();
@ -460,22 +442,20 @@ where
return Some((
start_loc,
Tok::ConfigComment { content: content.trim().into() },
self.location
self.location,
));
}
}
}
}
self.next_char();
};
}
}
fn unicode_literal(&mut self, literal_number: usize) -> Result<char, LexicalError> {
let mut p: u32 = 0u32;
let unicode_error = LexicalError {
error: LexicalErrorType::UnicodeError,
location: self.get_pos(),
};
let unicode_error =
LexicalError { error: LexicalErrorType::UnicodeError, location: self.get_pos() };
for i in 1..=literal_number {
match self.next_char() {
Some(c) => match c.to_digit(16) {
@ -530,10 +510,8 @@ where
}
}
}
unicode_names2::character(&name).ok_or(LexicalError {
error: LexicalErrorType::UnicodeError,
location: start_pos,
})
unicode_names2::character(&name)
.ok_or(LexicalError { error: LexicalErrorType::UnicodeError, location: start_pos })
}
fn lex_string(
@ -650,14 +628,9 @@ where
let end_pos = self.get_pos();
let tok = if is_bytes {
Tok::Bytes {
value: string_content.chars().map(|c| c as u8).collect(),
}
Tok::Bytes { value: string_content.chars().map(|c| c as u8).collect() }
} else {
Tok::String {
value: string_content,
is_fstring,
}
Tok::String { value: string_content, is_fstring }
};
Ok((start_pos, tok, end_pos))
@ -842,11 +815,7 @@ where
let tok_start = self.get_pos();
self.next_char();
let tok_end = self.get_pos();
self.emit((
tok_start,
Tok::Name { name: c.to_string().into() },
tok_end,
));
self.emit((tok_start, Tok::Name { name: c.to_string().into() }, tok_end));
} else {
self.consume_character(c)?;
}
@ -1439,14 +1408,8 @@ class Foo(A, B):
assert_eq!(
tokens,
vec![
Tok::String {
value: "\\\\".to_owned(),
is_fstring: false,
},
Tok::String {
value: "\\".to_owned(),
is_fstring: false,
},
Tok::String { value: "\\\\".to_owned(), is_fstring: false },
Tok::String { value: "\\".to_owned(), is_fstring: false },
Tok::Newline,
]
);
@ -1459,27 +1422,13 @@ class Foo(A, B):
assert_eq!(
tokens,
vec![
Tok::Int {
value: 47i128,
},
Tok::Int {
value: 13i128,
},
Tok::Int {
value: 0i128,
},
Tok::Int {
value: 123i128,
},
Tok::Int { value: 47i128 },
Tok::Int { value: 13i128 },
Tok::Int { value: 0i128 },
Tok::Int { value: 123i128 },
Tok::Float { value: 0.2 },
Tok::Complex {
real: 0.0,
imag: 2.0,
},
Tok::Complex {
real: 0.0,
imag: 2.2,
},
Tok::Complex { real: 0.0, imag: 2.0 },
Tok::Complex { real: 0.0, imag: 2.2 },
Tok::Newline,
]
);
@ -1539,21 +1488,13 @@ class Foo(A, B):
assert_eq!(
tokens,
vec![
Tok::Name {
name: String::from("avariable").into(),
},
Tok::Name { name: String::from("avariable").into() },
Tok::Equal,
Tok::Int {
value: 99i128
},
Tok::Int { value: 99i128 },
Tok::Plus,
Tok::Int {
value: 2i128
},
Tok::Int { value: 2i128 },
Tok::Minus,
Tok::Int {
value: 0i128
},
Tok::Int { value: 0i128 },
Tok::Newline,
]
);
@ -1740,42 +1681,15 @@ class Foo(A, B):
assert_eq!(
tokens,
vec![
Tok::String {
value: String::from("double"),
is_fstring: false,
},
Tok::String {
value: String::from("single"),
is_fstring: false,
},
Tok::String {
value: String::from("can't"),
is_fstring: false,
},
Tok::String {
value: String::from("\\\""),
is_fstring: false,
},
Tok::String {
value: String::from("\t\r\n"),
is_fstring: false,
},
Tok::String {
value: String::from("\\g"),
is_fstring: false,
},
Tok::String {
value: String::from("raw\\'"),
is_fstring: false,
},
Tok::String {
value: String::from("Đ"),
is_fstring: false,
},
Tok::String {
value: String::from("\u{80}\u{0}a"),
is_fstring: false,
},
Tok::String { value: String::from("double"), is_fstring: false },
Tok::String { value: String::from("single"), is_fstring: false },
Tok::String { value: String::from("can't"), is_fstring: false },
Tok::String { value: String::from("\\\""), is_fstring: false },
Tok::String { value: String::from("\t\r\n"), is_fstring: false },
Tok::String { value: String::from("\\g"), is_fstring: false },
Tok::String { value: String::from("raw\\'"), is_fstring: false },
Tok::String { value: String::from("Đ"), is_fstring: false },
Tok::String { value: String::from("\u{80}\u{0}a"), is_fstring: false },
Tok::Newline,
]
);
@ -1840,41 +1754,17 @@ class Foo(A, B):
fn test_raw_byte_literal() {
let source = r"rb'\x1z'";
let tokens = lex_source(source);
assert_eq!(
tokens,
vec![
Tok::Bytes {
value: b"\\x1z".to_vec()
},
Tok::Newline
]
);
assert_eq!(tokens, vec![Tok::Bytes { value: b"\\x1z".to_vec() }, Tok::Newline]);
let source = r"rb'\\'";
let tokens = lex_source(source);
assert_eq!(
tokens,
vec![
Tok::Bytes {
value: b"\\\\".to_vec()
},
Tok::Newline
]
)
assert_eq!(tokens, vec![Tok::Bytes { value: b"\\\\".to_vec() }, Tok::Newline])
}
#[test]
fn test_escape_octet() {
let source = r##"b'\43a\4\1234'"##;
let tokens = lex_source(source);
assert_eq!(
tokens,
vec![
Tok::Bytes {
value: b"#a\x04S4".to_vec()
},
Tok::Newline
]
)
assert_eq!(tokens, vec![Tok::Bytes { value: b"#a\x04S4".to_vec() }, Tok::Newline])
}
#[test]
@ -1883,13 +1773,7 @@ class Foo(A, B):
let tokens = lex_source(source);
assert_eq!(
tokens,
vec![
Tok::String {
value: "\u{2002}".to_owned(),
is_fstring: false,
},
Tok::Newline
]
vec![Tok::String { value: "\u{2002}".to_owned(), is_fstring: false }, Tok::Newline]
)
}
}

View File

@ -31,5 +31,5 @@ lalrpop_mod!(
#[allow(unused)]
python
);
pub mod token;
pub mod config_comment_helper;
pub mod token;

View File

@ -75,9 +75,7 @@ pub fn parse(source: &str, mode: Mode, file: FileName) -> Result<ast::Mod, Parse
let marker_token = (Default::default(), mode.to_marker(), Default::default());
let tokenizer = iter::once(Ok(marker_token)).chain(lxr);
python::TopParser::new()
.parse(tokenizer)
.map_err(ParseError::from)
python::TopParser::new().parse(tokenizer).map_err(ParseError::from)
}
#[cfg(test)]
@ -163,7 +161,7 @@ class Foo(A, B):
let parse_ast = parse_expression(&source).unwrap();
insta::assert_debug_snapshot!(parse_ast);
}
#[test]
fn test_more_comment() {
let source = "\
@ -185,7 +183,7 @@ while i < 2: # nac3: 4
3";
insta::assert_debug_snapshot!(parse_program(source, Default::default()).unwrap());
}
#[test]
fn test_sample_comment() {
let source = "\

View File

@ -1,7 +1,7 @@
//! Different token definitions.
//! Loosely based on token.h from CPython source:
use std::fmt::{self, Write};
use crate::ast;
use std::fmt::{self, Write};
/// Python source code can be tokenized in a sequence of these tokens.
#[derive(Clone, Debug, PartialEq)]
@ -111,8 +111,16 @@ impl fmt::Display for Tok {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use Tok::*;
match self {
Name { name } => write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name)),
Int { value } => if *value != i128::MAX { write!(f, "'{}'", value) } else { write!(f, "'#OFL#'") },
Name { name } => {
write!(f, "'{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *name))
}
Int { value } => {
if *value != i128::MAX {
write!(f, "'{}'", value)
} else {
write!(f, "'#OFL#'")
}
}
Float { value } => write!(f, "'{}'", value),
Complex { real, imag } => write!(f, "{}j{}", real, imag),
String { value, is_fstring } => {
@ -134,7 +142,11 @@ impl fmt::Display for Tok {
}
f.write_str("\"")
}
ConfigComment { content } => write!(f, "ConfigComment: '{}'", ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)),
ConfigComment { content } => write!(
f,
"ConfigComment: '{}'",
ast::get_str_from_ref(&ast::get_str_ref_lock(), *content)
),
Newline => f.write_str("Newline"),
Indent => f.write_str("Indent"),
Dedent => f.write_str("Dedent"),

View File

@ -9,8 +9,8 @@ use nac3core::{
};
use nac3parser::ast::{self, StrRef};
use parking_lot::{Mutex, RwLock};
use std::{collections::HashMap, sync::Arc};
use std::collections::HashSet;
use std::{collections::HashMap, sync::Arc};
pub struct ResolverInternal {
pub id_to_type: Mutex<HashMap<StrRef, Type>>,
@ -63,10 +63,12 @@ impl SymbolResolver for Resolver {
}
fn get_identifier_def(&self, id: StrRef) -> Result<DefinitionId, HashSet<String>> {
self.0.id_to_def.lock().get(&id).copied()
.ok_or_else(|| HashSet::from([
format!("Undefined identifier `{id}`"),
]))
self.0
.id_to_def
.lock()
.get(&id)
.copied()
.ok_or_else(|| HashSet::from([format!("Undefined identifier `{id}`")]))
}
fn get_string_id(&self, s: &str) -> i32 {

View File

@ -1,14 +1,11 @@
use clap::Parser;
use inkwell::{
memory_buffer::MemoryBuffer,
passes::PassBuilderOptions,
support::is_multithreaded,
targets::*,
memory_buffer::MemoryBuffer, passes::PassBuilderOptions, support::is_multithreaded, targets::*,
OptimizationLevel,
};
use parking_lot::{Mutex, RwLock};
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use std::collections::HashSet;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use nac3core::{
codegen::{
@ -18,7 +15,7 @@ use nac3core::{
symbol_resolver::SymbolResolver,
toplevel::{
composer::{ComposerConfig, TopLevelComposer},
helper::parse_parameter_default_value,
helper::parse_parameter_default_value,
type_annotation::*,
TopLevelDef,
},
@ -78,19 +75,18 @@ fn handle_typevar_definition(
primitives: &PrimitiveStore,
) -> Result<Type, HashSet<String>> {
let ExprKind::Call { func, args, .. } = &var.node else {
return Err(HashSet::from([
format!(
"expression {var:?} cannot be handled as a generic parameter in global scope"
),
]))
return Err(HashSet::from([format!(
"expression {var:?} cannot be handled as a generic parameter in global scope"
)]));
};
match &func.node {
ExprKind::Name { id, .. } if id == &"TypeVar".into() => {
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(HashSet::from([
format!("Expected string constant for first parameter of `TypeVar`, got {:?}", &args[0].node),
]))
return Err(HashSet::from([format!(
"Expected string constant for first parameter of `TypeVar`, got {:?}",
&args[0].node
)]));
};
let generic_name: StrRef = ty_name.to_string().into();
@ -106,17 +102,15 @@ fn handle_typevar_definition(
x,
HashMap::default(),
)?;
get_type_from_type_annotation_kinds(
def_list, unifier, &ty, &mut None
)
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)
})
.collect::<Result<Vec<_>, _>>()?;
let loc = func.location;
if constraints.len() == 1 {
return Err(HashSet::from([
format!("A single constraint is not allowed (at {loc})"),
]))
return Err(HashSet::from([format!(
"A single constraint is not allowed (at {loc})"
)]));
}
Ok(unifier.get_fresh_var_with_range(&constraints, Some(generic_name), Some(loc)).0)
@ -124,18 +118,17 @@ fn handle_typevar_definition(
ExprKind::Name { id, .. } if id == &"ConstGeneric".into() => {
if args.len() != 2 {
return Err(HashSet::from([
format!("Expected 2 arguments for `ConstGeneric`, got {}", args.len()),
]))
return Err(HashSet::from([format!(
"Expected 2 arguments for `ConstGeneric`, got {}",
args.len()
)]));
}
let ExprKind::Constant { value: Constant::Str(ty_name), .. } = &args[0].node else {
return Err(HashSet::from([
format!(
"Expected string constant for first parameter of `ConstGeneric`, got {:?}",
&args[0].node
),
]))
return Err(HashSet::from([format!(
"Expected string constant for first parameter of `ConstGeneric`, got {:?}",
&args[0].node
)]));
};
let generic_name: StrRef = ty_name.to_string().into();
@ -147,19 +140,16 @@ fn handle_typevar_definition(
&args[1],
HashMap::default(),
)?;
let constraint = get_type_from_type_annotation_kinds(
def_list, unifier, &ty, &mut None
)?;
let constraint =
get_type_from_type_annotation_kinds(def_list, unifier, &ty, &mut None)?;
let loc = func.location;
Ok(unifier.get_fresh_const_generic_var(constraint, Some(generic_name), Some(loc)).0)
}
_ => Err(HashSet::from([
format!(
"expression {var:?} cannot be handled as a generic parameter in global scope"
),
]))
_ => Err(HashSet::from([format!(
"expression {var:?} cannot be handled as a generic parameter in global scope"
)])),
}
}
@ -175,18 +165,12 @@ fn handle_assignment_pattern(
if targets.len() == 1 {
match &targets[0].node {
ExprKind::Name { id, .. } => {
if let Ok(var) = handle_typevar_definition(
value,
resolver,
def_list,
unifier,
primitives,
) {
if let Ok(var) =
handle_typevar_definition(value, resolver, def_list, unifier, primitives)
{
internal_resolver.add_id_type(*id, var);
Ok(())
} else if let Ok(val) =
parse_parameter_default_value(value, resolver)
{
} else if let Ok(val) = parse_parameter_default_value(value, resolver) {
internal_resolver.add_module_global(*id, val);
Ok(())
} else {
@ -238,10 +222,7 @@ fn handle_assignment_pattern(
))
}
}
_ => Err(format!(
"unpack of this expression is not supported at {}",
value.location
)),
_ => Err(format!("unpack of this expression is not supported at {}", value.location)),
}
}
}
@ -250,15 +231,8 @@ fn main() {
const SIZE_T: u32 = usize::BITS;
let cli = CommandLineArgs::parse();
let CommandLineArgs {
file_name,
threads,
opt_level,
emit_llvm,
triple,
mcpu,
target_features,
} = cli;
let CommandLineArgs { file_name, threads, opt_level, emit_llvm, triple, mcpu, target_features } =
cli;
Target::initialize_all(&InitializationConfig::default());
@ -270,9 +244,7 @@ fn main() {
let target_features = target_features.unwrap_or_default();
let threads = if is_multithreaded() {
if threads == 0 {
std::thread::available_parallelism()
.map(|threads| threads.get() as u32)
.unwrap_or(1u32)
std::thread::available_parallelism().map(|threads| threads.get() as u32).unwrap_or(1u32)
} else {
threads
}
@ -308,7 +280,8 @@ fn main() {
class_names: Mutex::default(),
module_globals: Mutex::default(),
str_store: Mutex::default(),
}.into();
}
.into();
let resolver =
Arc::new(Resolver(internal_resolver.clone())) as Arc<dyn SymbolResolver + Send + Sync>;
@ -332,13 +305,19 @@ fn main() {
eprintln!("{err}");
return;
}
},
}
// allow (and ignore) "from __future__ import annotations"
StmtKind::ImportFrom { module, names, .. }
if module == &Some("__future__".into()) && names.len() == 1 && names[0].name == "annotations".into() => (),
if module == &Some("__future__".into())
&& names.len() == 1
&& names[0].name == "annotations".into() =>
{
()
}
_ => {
let (name, def_id, ty) =
composer.register_top_level(stmt, Some(resolver.clone()), "__main__", true).unwrap();
let (name, def_id, ty) = composer
.register_top_level(stmt, Some(resolver.clone()), "__main__", true)
.unwrap();
internal_resolver.add_id_def(name, def_id);
if let Some(ty) = ty {
internal_resolver.add_id_type(name, ty);
@ -364,7 +343,8 @@ fn main() {
.unwrap_or_else(|_| panic!("cannot find run() entry point"))
.0]
.write();
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance else {
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance
else {
unreachable!()
};
instance_to_symbol.insert(String::new(), "run".to_string());
@ -444,7 +424,8 @@ fn main() {
function_iter = func.get_next_function();
}
let target_machine = llvm_options.target
let target_machine = llvm_options
.target
.create_target_machine(llvm_options.opt_level)
.expect("couldn't create target machine");

View File

@ -47,12 +47,11 @@ pub extern "C" fn __nac3_personality(_state: u32, _exception_object: u32, _conte
unimplemented!();
}
fn main() {
let filename = env::args().nth(1).unwrap();
unsafe {
let lib = libloading::Library::new(filename).unwrap();
let func: libloading::Symbol<unsafe extern fn()> = lib.get(b"__modinit__").unwrap();
let func: libloading::Symbol<unsafe extern "C" fn()> = lib.get(b"__modinit__").unwrap();
func()
}
}