forked from M-Labs/nac3
1
0
Fork 0

nac3artiq: get_obj_value take an additional argument for expected type

This commit is contained in:
ychenfo 2022-04-10 01:02:52 +08:00
parent 0e0871bc38
commit 089bba96a3
2 changed files with 68 additions and 46 deletions

View File

@ -517,7 +517,7 @@ pub fn attributes_writeback<'ctx, 'a>(
// we only care about primitive attributes
// for non-primitive attributes, they should be in another global
let mut attributes = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
for (name, (field_ty, is_mutable)) in fields.iter() {
if !is_mutable {
continue
@ -542,7 +542,7 @@ pub fn attributes_writeback<'ctx, 'a>(
let pydict = PyDict::new(py);
pydict.set_item("obj", val)?;
host_attributes.append(pydict)?;
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap()));
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
}
},
_ => {}

View File

@ -131,7 +131,7 @@ impl StaticValue for PythonValue {
&self,
ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator,
_expected_ty: Type,
expected_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> {
if let Some(val) = self.resolver.id_to_primitive.read().get(&self.id) {
return Ok(match val {
@ -151,7 +151,7 @@ impl StaticValue for PythonValue {
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
self.resolver
.get_obj_value(py, self.value.as_ref(py), ctx, generator)
.get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty)
.map(Option::unwrap)
}).map_err(|e| e.to_string())
}
@ -693,6 +693,7 @@ impl InnerResolver {
obj: &PyAny,
ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator,
expected_ty: Type,
) -> PyResult<Option<BasicValueEnum<'ctx>>> {
let ty_id: u64 =
self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?;
@ -729,28 +730,25 @@ impl InnerResolver {
Ok(Some(ctx.ctx.f64_type().const_float(val).into()))
} else if ty_id == self.primitive_ids.list {
let id_str = id.to_string();
if let Some(global) = ctx.module.get_global(&id_str) {
return Ok(Some(global.as_pointer_value().into()));
}
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
let ty = if len == 0 {
ctx.primitives.int32
let elem_ty =
if let TypeEnum::TList { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref()
{
*ty
} else {
self.get_list_elem_type(
py,
obj,
len,
&mut ctx.unifier,
&ctx.top_level.definitions.read(),
&ctx.primitives,
)?
.unwrap()
unreachable!("must be list")
};
let ty = ctx.get_llvm_type(generator, ty);
let ty = ctx.get_llvm_type(generator, elem_ty);
let size_t = generator.get_size_type(ctx.ctx);
let arr_ty = ctx
.ctx
.struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false);
{
if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
@ -761,13 +759,20 @@ impl InnerResolver {
self.global_value_ids.write().insert(id, obj.into());
}
}
let arr: Result<Option<Vec<_>>, _> = (0..len)
.map(|i| {
obj.get_item(i).and_then(|elem| self.get_obj_value(py, elem, ctx, generator).map_err(
|e| super::CompileError::new_err(format!("Error getting element {}: {}", i, e))))
obj
.get_item(i)
.and_then(|elem| self.get_obj_value(py, elem, ctx, generator, elem_ty)
.map_err(
|e| super::CompileError::new_err(
format!("Error getting element {}: {}", i, e))
))
})
.collect();
let arr = arr?.unwrap();
let arr_global = ctx.module.add_global(
ty.array_type(len as u32),
Some(AddressSpace::Generic),
@ -793,29 +798,59 @@ impl InnerResolver {
}
.into();
arr_global.set_initializer(&arr);
let val = arr_ty.const_named_struct(&[
arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::Generic)).into(),
size_t.const_int(len as u64, false).into(),
]);
let global = ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str);
global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple {
if let TypeEnum::TTuple { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
let tup_tys = ty.iter();
let elements: &PyTuple = obj.cast_as()?;
assert_eq!(elements.len(), tup_tys.len());
let val: Result<Option<Vec<_>>, _> =
elements.iter().enumerate().map(|(i, elem)| self.get_obj_value(py, elem, ctx, generator).map_err(|e|
super::CompileError::new_err(format!("Error getting element {}: {}", i, e)))).collect();
elements
.iter()
.enumerate()
.zip(tup_tys)
.map(|((i, elem), ty)| self
.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e|
super::CompileError::new_err(
format!("Error getting element {}: {}", i, e)
)
)
).collect();
let val = val?.unwrap();
let val = ctx.ctx.const_struct(&val, false);
Ok(Some(val.into()))
} else {
unreachable!("must expect tuple type")
}
} else if ty_id == self.primitive_ids.option {
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) =>
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type")
};
if id == self.primitive_ids.none {
// for option type, just a null ptr, whose type needs to be casted in codegen
// according to the type info attached in the ast
Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into()))
// for option type, just a null ptr
Ok(Some(
ctx.get_llvm_type(generator, option_val_ty)
.ptr_type(AddressSpace::Generic)
.const_null()
.into(),
))
} else {
match self
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator)
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator, option_val_ty)
.map_err(|e| {
super::CompileError::new_err(format!(
"Error getting value of Option object: {}",
@ -843,9 +878,11 @@ impl InnerResolver {
}
} else {
let id_str = id.to_string();
if let Some(global) = ctx.module.get_global(&id_str) {
return Ok(Some(global.as_pointer_value().into()));
}
let top_level_defs = ctx.top_level.definitions.read();
let ty = self
.get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)?
@ -872,23 +909,8 @@ impl InnerResolver {
let values: Result<Option<Vec<_>>, _> = fields
.iter()
.map(|(name, ty, _)| {
let v = self.get_obj_value(py, obj.getattr(&name.to_string())?, ctx, generator)
.map_err(|e| super::CompileError::new_err(format!("Error getting field {}: {}", name, e)));
match (v, ctx.unifier.get_ty_immutable(*ty).as_ref()) {
(Ok(Some(v)), TypeEnum::TObj { obj_id, params, .. })
if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) =>
{
let actual_ptr_ty = ctx
.get_llvm_type(generator, *params.iter().next().unwrap().1)
.ptr_type(AddressSpace::Generic);
Ok(Some(ctx.builder.build_bitcast(
v,
actual_ptr_ty,
"option_none_ptr_cast",
)))
}
(v, _) => v,
}
self.get_obj_value(py, obj.getattr(&name.to_string())?, ctx, generator, *ty)
.map_err(|e| super::CompileError::new_err(format!("Error getting field {}: {}", name, e)))
})
.collect();
let values = values?;