LLVM Call Parameter Mismatch Fix #293

Closed
ychenfo wants to merge 3 commits from fix-call-param-type into master
5 changed files with 168 additions and 49 deletions

View File

@ -165,6 +165,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
) -> BasicTypeEnum<'ctx> { ) -> BasicTypeEnum<'ctx> {
get_llvm_type( get_llvm_type(
self.ctx, self.ctx,
&self.module,
generator, generator,
&mut self.unifier, &mut self.unifier,
self.top_level, self.top_level,
@ -361,17 +362,31 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
} }
} }
} }
let params = if loc_params.is_empty() { let params = if loc_params.is_empty() { params } else { &loc_params };
params let params = fun
} else { .get_type()
&loc_params .get_param_types()
}; .into_iter()
.zip(params.iter())
.map(|(ty, val)| match (ty, val.get_type()) {
(BasicTypeEnum::PointerType(arg_ty), BasicTypeEnum::PointerType(val_ty))
if {
ty != val.get_type()
&& arg_ty.get_element_type().is_struct_type()
&& val_ty.get_element_type().is_struct_type()
} =>
{
self.builder.build_bitcast(*val, arg_ty, "call_arg_cast")
}
_ => *val,
})
.collect_vec();
let result = if let Some(target) = self.unwind_target { let result = if let Some(target) = self.unwind_target {
let current = self.builder.get_insert_block().unwrap().get_parent().unwrap(); let current = self.builder.get_insert_block().unwrap().get_parent().unwrap();
let then_block = self.ctx.append_basic_block(current, &format!("after.{}", call_name)); let then_block = self.ctx.append_basic_block(current, &format!("after.{}", call_name));
let result = self let result = self
.builder .builder
.build_invoke(fun, params, then_block, target, call_name) .build_invoke(fun, &params, then_block, target, call_name)
.try_as_basic_value() .try_as_basic_value()
.left(); .left();
self.builder.position_at_end(then_block); self.builder.position_at_end(then_block);

View File

@ -205,7 +205,7 @@ impl WorkerRegistry {
fn worker_thread<G: CodeGenerator>(&self, generator: &mut G, f: Arc<WithCall>) { fn worker_thread<G: CodeGenerator>(&self, generator: &mut G, f: Arc<WithCall>) {
let context = Context::create(); let context = Context::create();
let mut builder = context.create_builder(); let mut builder = context.create_builder();
let module = context.create_module(generator.get_name()); let mut module = context.create_module(generator.get_name());
module.add_basic_value_flag( module.add_basic_value_flag(
"Debug Info Version", "Debug Info Version",
@ -225,16 +225,17 @@ impl WorkerRegistry {
let mut errors = HashSet::new(); let mut errors = HashSet::new();
while let Some(task) = self.receiver.recv().unwrap() { while let Some(task) = self.receiver.recv().unwrap() {
let tmp_module = context.create_module("tmp"); let prev_module = module.write_bitcode_to_memory();
match gen_func(&context, generator, self, builder, tmp_module, task) { match gen_func(&context, generator, self, builder, module, task) {
Ok(result) => { Ok(result) => {
builder = result.0; builder = result.0;
passes.run_on(&result.2); passes.run_on(&result.2);
module.link_in_module(result.1).unwrap(); module = result.1;
} }
Err((old_builder, e)) => { Err((old_builder, e)) => {
builder = old_builder; builder = old_builder;
errors.insert(e); errors.insert(e);
module = context.create_module_from_ir(prev_module).unwrap();
Review

What's the impact on performance of this?
It sounds strange that we would have to serialize and deserialize things for passing objects around in the same program.

What's the impact on performance of this? It sounds strange that we would have to serialize and deserialize things for passing objects around in the same program.
Review

Thanks for pointing this out! I just ran some benchmarks and this does indeed slow things down significantly, as shown below. I will look into this.

Python programs with 10 to 150 functions (50 LOC per function); compile time in seconds:

10 20 30 40 50 60 70 80 90 100 110 120 130 140 150
before 0.67500 0.82808 0.95160 1.04178 1.12146 1.2736 1.3426 1.43848 1.52762 1.66717 1.8724 1.9046 1.9810 2.1185 2.1865
after 1.00600 1.2652 1.48977 1.6315 1.78203 2.1227 2.3221 2.5354 2.6856 3.0505 3.5356 3.6710 4.0322 4.1374 4.3924
Thanks for pointing this out! I just ran some benchmarks and this does indeed slow things down significantly, as shown below. I will look into this. Python programs with 10 to 150 functions (50 LOC per function); compile time in seconds: | | 10 | 20 | 30 | 40 | 50 | 60 | 70 | 80 | 90 | 100 | 110 | 120 | 130 | 140 | 150 | | ------ | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | | before |0.67500 | 0.82808 | 0.95160 | 1.04178 | 1.12146 | 1.2736 | 1.3426 | 1.43848 | 1.52762 | 1.66717 | 1.8724 | 1.9046 | 1.9810 | 2.1185 | 2.1865 | | after |1.00600 | 1.2652 | 1.48977 | 1.6315 | 1.78203 | 2.1227 | 2.3221 | 2.5354 | 2.6856 | 3.0505 | 3.5356 | 3.6710 | 4.0322 | 4.1374 | 4.3924 |
Review

Ah, actually there seems to be no need to precisely restore the previous LLVM module; creating a new empty one just to collect errors should be enough.

Ah, actually there seems to be no need to precisely restore the previous LLVM module; creating a new empty one just to collect errors should be enough.
} }
} }
*self.task_count.lock() -= 1; *self.task_count.lock() -= 1;
@ -271,6 +272,7 @@ pub struct CodeGenTask {
fn get_llvm_type<'ctx>( fn get_llvm_type<'ctx>(
ctx: &'ctx Context, ctx: &'ctx Context,
module: &Module<'ctx>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
unifier: &mut Unifier, unifier: &mut Unifier,
top_level: &TopLevelContext, top_level: &TopLevelContext,
@ -295,6 +297,7 @@ fn get_llvm_type<'ctx>(
) if *obj_id == *opt_id => { ) if *obj_id == *opt_id => {
return get_llvm_type( return get_llvm_type(
ctx, ctx,
module,
generator, generator,
unifier, unifier,
top_level, top_level,
@ -311,27 +314,37 @@ fn get_llvm_type<'ctx>(
// a struct with fields in the order of declaration // a struct with fields in the order of declaration
let top_level_defs = top_level.definitions.read(); let top_level_defs = top_level.definitions.read();
let definition = top_level_defs.get(obj_id.0).unwrap(); let definition = top_level_defs.get(obj_id.0).unwrap();
let ty = if let TopLevelDef::Class { name, fields: fields_list, .. } = let ty = if let TopLevelDef::Class { fields: fields_list, .. } =
&*definition.read() &*definition.read()
{ {
let struct_type = ctx.opaque_struct_type(&name.to_string()); let name = unifier.stringify(ty);
type_cache.insert(unifier.get_representative(ty), struct_type.ptr_type(AddressSpace::Generic).into()); match module.get_struct_type(&name) {
let fields = fields_list Some(t) => t.ptr_type(AddressSpace::Generic).into(),
.iter() None => {
.map(|f| { let struct_type = ctx.opaque_struct_type(&name);
get_llvm_type( type_cache.insert(
ctx, unifier.get_representative(ty),
generator, struct_type.ptr_type(AddressSpace::Generic).into()
unifier, );
top_level, let fields = fields_list
type_cache, .iter()
primitives, .map(|f| {
fields[&f.0].0, get_llvm_type(
) ctx,
}) module,
.collect_vec(); generator,
struct_type.set_body(&fields, false); unifier,
struct_type.ptr_type(AddressSpace::Generic).into() top_level,
type_cache,
primitives,
fields[&f.0].0,
)
})
.collect_vec();
struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::Generic).into()
}
}
} else { } else {
unreachable!() unreachable!()
}; };
@ -341,14 +354,19 @@ fn get_llvm_type<'ctx>(
// a struct with fields in the order present in the tuple // a struct with fields in the order present in the tuple
let fields = ty let fields = ty
.iter() .iter()
.map(|ty| get_llvm_type(ctx, generator, unifier, top_level, type_cache, primitives, *ty)) .map(|ty| {
get_llvm_type(
ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
)
})
.collect_vec(); .collect_vec();
ctx.struct_type(&fields, false).into() ctx.struct_type(&fields, false).into()
} }
TList { ty } => { TList { ty } => {
// a struct with an integer and a pointer to an array // a struct with an integer and a pointer to an array
let element_type = let element_type = get_llvm_type(
get_llvm_type(ctx, generator, unifier, top_level, type_cache, primitives, *ty); ctx, module, generator, unifier, top_level, type_cache, primitives, *ty,
);
let fields = [ let fields = [
element_type.ptr_type(AddressSpace::Generic).into(), element_type.ptr_type(AddressSpace::Generic).into(),
generator.get_size_type(ctx).into(), generator.get_size_type(ctx).into(),
@ -434,28 +452,40 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
(primitives.float, context.f64_type().into()), (primitives.float, context.f64_type().into()),
(primitives.bool, context.bool_type().into()), (primitives.bool, context.bool_type().into()),
(primitives.str, { (primitives.str, {
let str_type = context.opaque_struct_type("str"); let name = "str";
let fields = [ match module.get_struct_type(name) {
context.i8_type().ptr_type(AddressSpace::Generic).into(), None => {
generator.get_size_type(context).into(), let str_type = context.opaque_struct_type("str");
]; let fields = [
str_type.set_body(&fields, false); context.i8_type().ptr_type(AddressSpace::Generic).into(),
str_type.into() generator.get_size_type(context).into(),
];
str_type.set_body(&fields, false);
str_type.into()
}
Some(t) => t.as_basic_type_enum()
}
}), }),
(primitives.range, context.i32_type().array_type(3).ptr_type(AddressSpace::Generic).into()), (primitives.range, context.i32_type().array_type(3).ptr_type(AddressSpace::Generic).into()),
(primitives.exception, {
let name = "Exception";
match module.get_struct_type(name) {
Some(t) => t.ptr_type(AddressSpace::Generic).as_basic_type_enum(),
None => {
let exception = context.opaque_struct_type("Exception");
let int32 = context.i32_type().into();
let int64 = context.i64_type().into();
let str_ty = module.get_struct_type("str").unwrap().as_basic_type_enum();
let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::Generic).as_basic_type_enum()
}
}
})
] ]
.iter() .iter()
.cloned() .cloned()
.collect(); .collect();
type_cache.insert(primitives.exception, {
let exception = context.opaque_struct_type("Exception");
let int32 = context.i32_type().into();
let int64 = context.i64_type().into();
let str_ty = *type_cache.get(&primitives.str).unwrap();
let fields = [int32, str_ty, int32, int32, str_ty, str_ty, int64, int64, int64];
exception.set_body(&fields, false);
exception.ptr_type(AddressSpace::Generic).into()
});
// NOTE: special handling of option cannot use this type cache since it contains type var, // NOTE: special handling of option cannot use this type cache since it contains type var,
// handled inside get_llvm_type instead // handled inside get_llvm_type instead
@ -478,7 +508,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let ret_type = if unifier.unioned(ret, primitives.none) { let ret_type = if unifier.unioned(ret, primitives.none) {
None None
} else { } else {
Some(get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, &primitives, ret)) Some(get_llvm_type(context, &module, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, &primitives, ret))
}; };
let has_sret = ret_type.map_or(false, |ty| need_sret(context, ty)); let has_sret = ret_type.map_or(false, |ty| need_sret(context, ty));
@ -487,6 +517,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
.map(|arg| { .map(|arg| {
get_llvm_type( get_llvm_type(
context, context,
&module,
generator, generator,
&mut unifier, &mut unifier,
top_level_ctx.as_ref(), top_level_ctx.as_ref(),
@ -535,6 +566,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
let alloca = builder.build_alloca( let alloca = builder.build_alloca(
get_llvm_type( get_llvm_type(
context, context,
&module,
generator, generator,
&mut unifier, &mut unifier,
top_level_ctx.as_ref(), top_level_ctx.as_ref(),

View File

@ -0,0 +1,32 @@
from __future__ import annotations
@extern
def output_int32(x: int32):
...
class A:
a: int32
def __init__(self, a: int32):
self.a = a
def f1(self):
self.f2()
def f2(self):
output_int32(self.a)
class B(A):
b: int32
def __init__(self, b: int32):
self.a = b + 1
self.b = b
def run() -> int32:
aaa = A(5)
bbb = B(2)
aaa.f1()
bbb.f1()
return 0

View File

@ -0,0 +1,36 @@
from __future__ import annotations
@extern
def output_int32(a: int32):
...
class A:
d: int32
a: list[B]
def __init__(self, b: list[B]):
self.d = 123
self.a = b
def f(self):
output_int32(self.d)
class B:
a: A
def __init__(self, a: A):
self.a = a
def ff(self):
self.a.f()
class Demo:
a: A
def __init__(self, a: A):
self.a = a
def run() -> int32:
aa = A([])
bb = B(aa)
aa.a = [bb]
d = Demo(aa)
d.a.a[0].ff()
return 0

View File

@ -34,5 +34,9 @@ def run() -> int32:
insta = A() insta = A()
inst = C(insta) inst = C(insta)
inst.foo() inst.foo()
insta2 = B()
inst2 = C(insta2)
inst2.foo()
return 0 return 0