forked from M-Labs/nac3
1
0
Fork 0

nac3artiq: get_obj_value take an additional argument for expected type

This commit is contained in:
ychenfo 2022-04-10 01:02:52 +08:00
parent 0e0871bc38
commit 089bba96a3
2 changed files with 68 additions and 46 deletions

View File

@ -517,7 +517,7 @@ pub fn attributes_writeback<'ctx, 'a>(
// we only care about primitive attributes // we only care about primitive attributes
// for non-primitive attributes, they should be in another global // for non-primitive attributes, they should be in another global
let mut attributes = Vec::new(); let mut attributes = Vec::new();
let obj = inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap(); let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap();
for (name, (field_ty, is_mutable)) in fields.iter() { for (name, (field_ty, is_mutable)) in fields.iter() {
if !is_mutable { if !is_mutable {
continue continue
@ -542,7 +542,7 @@ pub fn attributes_writeback<'ctx, 'a>(
let pydict = PyDict::new(py); let pydict = PyDict::new(py);
pydict.set_item("obj", val)?; pydict.set_item("obj", val)?;
host_attributes.append(pydict)?; host_attributes.append(pydict)?;
values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap())); values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap()));
} }
}, },
_ => {} _ => {}

View File

@ -131,7 +131,7 @@ impl StaticValue for PythonValue {
&self, &self,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
_expected_ty: Type, expected_ty: Type,
) -> Result<BasicValueEnum<'ctx>, String> { ) -> Result<BasicValueEnum<'ctx>, String> {
if let Some(val) = self.resolver.id_to_primitive.read().get(&self.id) { if let Some(val) = self.resolver.id_to_primitive.read().get(&self.id) {
return Ok(match val { return Ok(match val {
@ -151,7 +151,7 @@ impl StaticValue for PythonValue {
Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> { Python::with_gil(|py| -> PyResult<BasicValueEnum<'ctx>> {
self.resolver self.resolver
.get_obj_value(py, self.value.as_ref(py), ctx, generator) .get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty)
.map(Option::unwrap) .map(Option::unwrap)
}).map_err(|e| e.to_string()) }).map_err(|e| e.to_string())
} }
@ -693,6 +693,7 @@ impl InnerResolver {
obj: &PyAny, obj: &PyAny,
ctx: &mut CodeGenContext<'ctx, 'a>, ctx: &mut CodeGenContext<'ctx, 'a>,
generator: &mut dyn CodeGenerator, generator: &mut dyn CodeGenerator,
expected_ty: Type,
) -> PyResult<Option<BasicValueEnum<'ctx>>> { ) -> PyResult<Option<BasicValueEnum<'ctx>>> {
let ty_id: u64 = let ty_id: u64 =
self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?; self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?;
@ -729,28 +730,25 @@ impl InnerResolver {
Ok(Some(ctx.ctx.f64_type().const_float(val).into())) Ok(Some(ctx.ctx.f64_type().const_float(val).into()))
} else if ty_id == self.primitive_ids.list { } else if ty_id == self.primitive_ids.list {
let id_str = id.to_string(); let id_str = id.to_string();
if let Some(global) = ctx.module.get_global(&id_str) { if let Some(global) = ctx.module.get_global(&id_str) {
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} }
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
let ty = if len == 0 { let elem_ty =
ctx.primitives.int32 if let TypeEnum::TList { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref()
{
*ty
} else { } else {
self.get_list_elem_type( unreachable!("must be list")
py,
obj,
len,
&mut ctx.unifier,
&ctx.top_level.definitions.read(),
&ctx.primitives,
)?
.unwrap()
}; };
let ty = ctx.get_llvm_type(generator, ty); let ty = ctx.get_llvm_type(generator, elem_ty);
let size_t = generator.get_size_type(ctx.ctx); let size_t = generator.get_size_type(ctx.ctx);
let arr_ty = ctx let arr_ty = ctx
.ctx .ctx
.struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false); .struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false);
{ {
if self.global_value_ids.read().contains_key(&id) { if self.global_value_ids.read().contains_key(&id) {
let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| {
@ -761,13 +759,20 @@ impl InnerResolver {
self.global_value_ids.write().insert(id, obj.into()); self.global_value_ids.write().insert(id, obj.into());
} }
} }
let arr: Result<Option<Vec<_>>, _> = (0..len) let arr: Result<Option<Vec<_>>, _> = (0..len)
.map(|i| { .map(|i| {
obj.get_item(i).and_then(|elem| self.get_obj_value(py, elem, ctx, generator).map_err( obj
|e| super::CompileError::new_err(format!("Error getting element {}: {}", i, e)))) .get_item(i)
.and_then(|elem| self.get_obj_value(py, elem, ctx, generator, elem_ty)
.map_err(
|e| super::CompileError::new_err(
format!("Error getting element {}: {}", i, e))
))
}) })
.collect(); .collect();
let arr = arr?.unwrap(); let arr = arr?.unwrap();
let arr_global = ctx.module.add_global( let arr_global = ctx.module.add_global(
ty.array_type(len as u32), ty.array_type(len as u32),
Some(AddressSpace::Generic), Some(AddressSpace::Generic),
@ -793,29 +798,59 @@ impl InnerResolver {
} }
.into(); .into();
arr_global.set_initializer(&arr); arr_global.set_initializer(&arr);
let val = arr_ty.const_named_struct(&[ let val = arr_ty.const_named_struct(&[
arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::Generic)).into(), arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::Generic)).into(),
size_t.const_int(len as u64, false).into(), size_t.const_int(len as u64, false).into(),
]); ]);
let global = ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str); let global = ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str);
global.set_initializer(&val); global.set_initializer(&val);
Ok(Some(global.as_pointer_value().into())) Ok(Some(global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple { } else if ty_id == self.primitive_ids.tuple {
if let TypeEnum::TTuple { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
let tup_tys = ty.iter();
let elements: &PyTuple = obj.cast_as()?; let elements: &PyTuple = obj.cast_as()?;
assert_eq!(elements.len(), tup_tys.len());
let val: Result<Option<Vec<_>>, _> = let val: Result<Option<Vec<_>>, _> =
elements.iter().enumerate().map(|(i, elem)| self.get_obj_value(py, elem, ctx, generator).map_err(|e| elements
super::CompileError::new_err(format!("Error getting element {}: {}", i, e)))).collect(); .iter()
.enumerate()
.zip(tup_tys)
.map(|((i, elem), ty)| self
.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e|
super::CompileError::new_err(
format!("Error getting element {}: {}", i, e)
)
)
).collect();
let val = val?.unwrap(); let val = val?.unwrap();
let val = ctx.ctx.const_struct(&val, false); let val = ctx.ctx.const_struct(&val, false);
Ok(Some(val.into())) Ok(Some(val.into()))
} else {
unreachable!("must expect tuple type")
}
} else if ty_id == self.primitive_ids.option { } else if ty_id == self.primitive_ids.option {
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. }
if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) =>
{
*params.iter().next().unwrap().1
}
_ => unreachable!("must be option type")
};
if id == self.primitive_ids.none { if id == self.primitive_ids.none {
// for option type, just a null ptr, whose type needs to be casted in codegen // for option type, just a null ptr
// according to the type info attached in the ast Ok(Some(
Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into())) ctx.get_llvm_type(generator, option_val_ty)
.ptr_type(AddressSpace::Generic)
.const_null()
.into(),
))
} else { } else {
match self match self
.get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator) .get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator, option_val_ty)
.map_err(|e| { .map_err(|e| {
super::CompileError::new_err(format!( super::CompileError::new_err(format!(
"Error getting value of Option object: {}", "Error getting value of Option object: {}",
@ -843,9 +878,11 @@ impl InnerResolver {
} }
} else { } else {
let id_str = id.to_string(); let id_str = id.to_string();
if let Some(global) = ctx.module.get_global(&id_str) { if let Some(global) = ctx.module.get_global(&id_str) {
return Ok(Some(global.as_pointer_value().into())); return Ok(Some(global.as_pointer_value().into()));
} }
let top_level_defs = ctx.top_level.definitions.read(); let top_level_defs = ctx.top_level.definitions.read();
let ty = self let ty = self
.get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)?
@ -872,23 +909,8 @@ impl InnerResolver {
let values: Result<Option<Vec<_>>, _> = fields let values: Result<Option<Vec<_>>, _> = fields
.iter() .iter()
.map(|(name, ty, _)| { .map(|(name, ty, _)| {
let v = self.get_obj_value(py, obj.getattr(&name.to_string())?, ctx, generator) self.get_obj_value(py, obj.getattr(&name.to_string())?, ctx, generator, *ty)
.map_err(|e| super::CompileError::new_err(format!("Error getting field {}: {}", name, e))); .map_err(|e| super::CompileError::new_err(format!("Error getting field {}: {}", name, e)))
match (v, ctx.unifier.get_ty_immutable(*ty).as_ref()) {
(Ok(Some(v)), TypeEnum::TObj { obj_id, params, .. })
if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) =>
{
let actual_ptr_ty = ctx
.get_llvm_type(generator, *params.iter().next().unwrap().1)
.ptr_type(AddressSpace::Generic);
Ok(Some(ctx.builder.build_bitcast(
v,
actual_ptr_ty,
"option_none_ptr_cast",
)))
}
(v, _) => v,
}
}) })
.collect(); .collect();
let values = values?; let values = values?;