From 089bba96a35cb0e3ee2ffb65ba0399275c9f083f Mon Sep 17 00:00:00 2001 From: ychenfo Date: Sun, 10 Apr 2022 01:02:52 +0800 Subject: [PATCH] nac3artiq: get_obj_value take an additional argument for expected type --- nac3artiq/src/codegen.rs | 4 +- nac3artiq/src/symbol_resolver.rs | 110 ++++++++++++++++++------------- 2 files changed, 68 insertions(+), 46 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 47b3bc1c..b63ee60c 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -517,7 +517,7 @@ pub fn attributes_writeback<'ctx, 'a>( // we only care about primitive attributes // for non-primitive attributes, they should be in another global let mut attributes = Vec::new(); - let obj = inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap(); + let obj = inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap(); for (name, (field_ty, is_mutable)) in fields.iter() { if !is_mutable { continue @@ -542,7 +542,7 @@ pub fn attributes_writeback<'ctx, 'a>( let pydict = PyDict::new(py); pydict.set_item("obj", val)?; host_attributes.append(pydict)?; - values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator)?.unwrap())); + values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); } }, _ => {} diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 284f06b3..2177bf8a 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -131,7 +131,7 @@ impl StaticValue for PythonValue { &self, ctx: &mut CodeGenContext<'ctx, 'a>, generator: &mut dyn CodeGenerator, - _expected_ty: Type, + expected_ty: Type, ) -> Result, String> { if let Some(val) = self.resolver.id_to_primitive.read().get(&self.id) { return Ok(match val { @@ -151,7 +151,7 @@ impl StaticValue for PythonValue { Python::with_gil(|py| -> PyResult> { self.resolver - .get_obj_value(py, self.value.as_ref(py), ctx, generator) + .get_obj_value(py, self.value.as_ref(py), ctx, generator, expected_ty) .map(Option::unwrap) }).map_err(|e| e.to_string()) } @@ -693,6 +693,7 @@ impl InnerResolver { obj: &PyAny, ctx: &mut CodeGenContext<'ctx, 'a>, generator: &mut dyn CodeGenerator, + expected_ty: Type, ) -> PyResult>> { let ty_id: u64 = self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?; @@ -729,28 +730,25 @@ impl InnerResolver { Ok(Some(ctx.ctx.f64_type().const_float(val).into())) } else if ty_id == self.primitive_ids.list { let id_str = id.to_string(); + if let Some(global) = ctx.module.get_global(&id_str) { return Ok(Some(global.as_pointer_value().into())); } + let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; - let ty = if len == 0 { - ctx.primitives.int32 + let elem_ty = + if let TypeEnum::TList { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref() + { + *ty } else { - self.get_list_elem_type( - py, - obj, - len, - &mut ctx.unifier, - &ctx.top_level.definitions.read(), - &ctx.primitives, - )? - .unwrap() + unreachable!("must be list") }; - let ty = ctx.get_llvm_type(generator, ty); + let ty = ctx.get_llvm_type(generator, elem_ty); let size_t = generator.get_size_type(ctx.ctx); let arr_ty = ctx .ctx .struct_type(&[ty.ptr_type(AddressSpace::Generic).into(), size_t.into()], false); + { if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { @@ -761,13 +759,20 @@ impl InnerResolver { self.global_value_ids.write().insert(id, obj.into()); } } + let arr: Result>, _> = (0..len) .map(|i| { - obj.get_item(i).and_then(|elem| self.get_obj_value(py, elem, ctx, generator).map_err( - |e| super::CompileError::new_err(format!("Error getting element {}: {}", i, e)))) + obj + .get_item(i) + .and_then(|elem| self.get_obj_value(py, elem, ctx, generator, elem_ty) + .map_err( + |e| super::CompileError::new_err( + format!("Error getting element {}: {}", i, e)) + )) }) .collect(); let arr = arr?.unwrap(); + let arr_global = ctx.module.add_global( ty.array_type(len as u32), Some(AddressSpace::Generic), @@ -793,29 +798,59 @@ impl InnerResolver { } .into(); arr_global.set_initializer(&arr); + let val = arr_ty.const_named_struct(&[ arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::Generic)).into(), size_t.const_int(len as u64, false).into(), ]); + let global = ctx.module.add_global(arr_ty, Some(AddressSpace::Generic), &id_str); global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) } else if ty_id == self.primitive_ids.tuple { - let elements: &PyTuple = obj.cast_as()?; - let val: Result>, _> = - elements.iter().enumerate().map(|(i, elem)| self.get_obj_value(py, elem, ctx, generator).map_err(|e| - super::CompileError::new_err(format!("Error getting element {}: {}", i, e)))).collect(); - let val = val?.unwrap(); - let val = ctx.ctx.const_struct(&val, false); - Ok(Some(val.into())) + if let TypeEnum::TTuple { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref() { + let tup_tys = ty.iter(); + let elements: &PyTuple = obj.cast_as()?; + assert_eq!(elements.len(), tup_tys.len()); + let val: Result>, _> = + elements + .iter() + .enumerate() + .zip(tup_tys) + .map(|((i, elem), ty)| self + .get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| + super::CompileError::new_err( + format!("Error getting element {}: {}", i, e) + ) + ) + ).collect(); + let val = val?.unwrap(); + let val = ctx.ctx.const_struct(&val, false); + Ok(Some(val.into())) + } else { + unreachable!("must expect tuple type") + } } else if ty_id == self.primitive_ids.option { + let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) => + { + *params.iter().next().unwrap().1 + } + _ => unreachable!("must be option type") + }; if id == self.primitive_ids.none { - // for option type, just a null ptr, whose type needs to be casted in codegen - // according to the type info attached in the ast - Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into())) + // for option type, just a null ptr + Ok(Some( + ctx.get_llvm_type(generator, option_val_ty) + .ptr_type(AddressSpace::Generic) + .const_null() + .into(), + )) } else { match self - .get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator) + .get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator, option_val_ty) .map_err(|e| { super::CompileError::new_err(format!( "Error getting value of Option object: {}", @@ -843,9 +878,11 @@ impl InnerResolver { } } else { let id_str = id.to_string(); + if let Some(global) = ctx.module.get_global(&id_str) { return Ok(Some(global.as_pointer_value().into())); } + let top_level_defs = ctx.top_level.definitions.read(); let ty = self .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? @@ -872,23 +909,8 @@ impl InnerResolver { let values: Result>, _> = fields .iter() .map(|(name, ty, _)| { - let v = self.get_obj_value(py, obj.getattr(&name.to_string())?, ctx, generator) - .map_err(|e| super::CompileError::new_err(format!("Error getting field {}: {}", name, e))); - match (v, ctx.unifier.get_ty_immutable(*ty).as_ref()) { - (Ok(Some(v)), TypeEnum::TObj { obj_id, params, .. }) - if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) => - { - let actual_ptr_ty = ctx - .get_llvm_type(generator, *params.iter().next().unwrap().1) - .ptr_type(AddressSpace::Generic); - Ok(Some(ctx.builder.build_bitcast( - v, - actual_ptr_ty, - "option_none_ptr_cast", - ))) - } - (v, _) => v, - } + self.get_obj_value(py, obj.getattr(&name.to_string())?, ctx, generator, *ty) + .map_err(|e| super::CompileError::new_err(format!("Error getting field {}: {}", name, e))) }) .collect(); let values = values?;