diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 63eb29d..8120ea2 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -950,549 +950,534 @@ impl InnerResolver { |_| Ok(Ok(extracted_ty)), ) } else if let Ok(s) = obj.extract::() { - if unifier.unioned(extracted_ty, primitives.str) { - Ok(Ok(primitives.str)) - } else { - Ok(Err(format!("expected str, got {s}"))) - } + if unifier.unioned(extracted_ty, primitives.str) { + Ok(Ok(primitives.str)) + } else { + Ok(Err(format!("expected str, got {s}"))) + } } else { Ok(Ok(extracted_ty)) } - } } } } +} - pub fn get_obj_value<'ctx>( - &self, - py: Python, - obj: &PyAny, - ctx: &mut CodeGenContext<'ctx, '_>, - generator: &mut dyn CodeGenerator, - expected_ty: Type, - ) -> PyResult>> { - let ty_id: u64 = - self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?; - let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; - if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { - let val: i32 = obj.extract().unwrap(); - self.id_to_primitive.write().insert(id, PrimitiveValue::I32(val)); - Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into())) - } else if ty_id == self.primitive_ids.int64 { - let val: i64 = obj.extract().unwrap(); - self.id_to_primitive.write().insert(id, PrimitiveValue::I64(val)); - Ok(Some(ctx.ctx.i64_type().const_int(val as u64, false).into())) - } else if ty_id == self.primitive_ids.uint32 { - let val: u32 = obj.extract().unwrap(); - self.id_to_primitive.write().insert(id, PrimitiveValue::U32(val)); - Ok(Some(ctx.ctx.i32_type().const_int(u64::from(val), false).into())) - } else if ty_id == self.primitive_ids.uint64 { - let val: u64 = obj.extract().unwrap(); - self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val)); - Ok(Some(ctx.ctx.i64_type().const_int(val, false).into())) - } else if ty_id == self.primitive_ids.bool { - let val: bool = obj.extract().unwrap(); - self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); - Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) - } else if ty_id == self.primitive_ids.np_bool_ { - let val: bool = obj.call_method("__bool__", (), None)?.extract().unwrap(); - self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); - Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) - } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { - let val: String = obj.extract().unwrap(); - self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); - return Ok(Some(ctx.gen_string(generator, val).into())); - } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { - let val: f64 = obj.extract().unwrap(); - self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val)); - Ok(Some(ctx.ctx.f64_type().const_float(val).into())) - } else if ty_id == self.primitive_ids.list { - let id_str = id.to_string(); +pub fn get_obj_value<'ctx>( + &self, + py: Python, + obj: &PyAny, + ctx: &mut CodeGenContext<'ctx, '_>, + generator: &mut dyn CodeGenerator, + expected_ty: Type, +) -> PyResult>> { + let ty_id: u64 = + self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?; + let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; + if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { + let val: i32 = obj.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::I32(val)); + Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into())) + } else if ty_id == self.primitive_ids.int64 { + let val: i64 = obj.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::I64(val)); + Ok(Some(ctx.ctx.i64_type().const_int(val as u64, false).into())) + } else if ty_id == self.primitive_ids.uint32 { + let val: u32 = obj.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::U32(val)); + Ok(Some(ctx.ctx.i32_type().const_int(u64::from(val), false).into())) + } else if ty_id == self.primitive_ids.uint64 { + let val: u64 = obj.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val)); + Ok(Some(ctx.ctx.i64_type().const_int(val, false).into())) + } else if ty_id == self.primitive_ids.bool { + let val: bool = obj.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); + Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::Bool(val)); + Ok(Some(ctx.ctx.i8_type().const_int(u64::from(val), false).into())) + } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { + let val: String = obj.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::Str(val.clone())); + return Ok(Some(ctx.gen_string(generator, val).into())); + } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { + let val: f64 = obj.extract().unwrap(); + self.id_to_primitive.write().insert(id, PrimitiveValue::F64(val)); + Ok(Some(ctx.ctx.f64_type().const_float(val).into())) + } else if ty_id == self.primitive_ids.list { + let id_str = id.to_string(); - if let Some(global) = ctx.module.get_global(&id_str) { + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + + let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; + let elem_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() { + TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { + iter_type_vars(params).nth(0).unwrap().ty + } + _ => unreachable!("must be list"), + }; + let size_t = generator.get_size_type(ctx.ctx); + let ty = if len == 0 + && matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. }) + { + // The default type for zero-length lists of unknown element type is size_t + size_t.into() + } else { + ctx.get_llvm_type(generator, elem_ty) + }; + let arr_ty = ctx + .ctx + .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); + + { + if self.global_value_ids.read().contains_key(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(arr_ty, Some(AddressSpace::default()), &id_str) + }); return Ok(Some(global.as_pointer_value().into())); } + self.global_value_ids.write().insert(id, obj.into()); + } - let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; - let elem_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() { - TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => { - iter_type_vars(params).nth(0).unwrap().ty - } - _ => unreachable!("must be list"), - }; - let size_t = generator.get_size_type(ctx.ctx); - let ty = if len == 0 - && matches!(&*ctx.unifier.get_ty_immutable(elem_ty), TypeEnum::TVar { .. }) - { - // The default type for zero-length lists of unknown element type is size_t - size_t.into() - } else { - ctx.get_llvm_type(generator, elem_ty) - }; - let arr_ty = ctx - .ctx - .struct_type(&[ty.ptr_type(AddressSpace::default()).into(), size_t.into()], false); - - { - if self.global_value_ids.read().contains_key(&id) { - let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { - ctx.module.add_global(arr_ty, Some(AddressSpace::default()), &id_str) - }); - return Ok(Some(global.as_pointer_value().into())); - } - self.global_value_ids.write().insert(id, obj.into()); - } - - let arr: Result>, _> = (0..len) - .map(|i| { - obj.get_item(i).and_then(|elem| { - self.get_obj_value(py, elem, ctx, generator, elem_ty).map_err(|e| { - super::CompileError::new_err(format!("Error getting element {i}: {e}")) - }) + let arr: Result>, _> = (0..len) + .map(|i| { + obj.get_item(i).and_then(|elem| { + self.get_obj_value(py, elem, ctx, generator, elem_ty).map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) }) }) - .collect(); - let arr = arr?.unwrap(); + }) + .collect(); + let arr = arr?.unwrap(); - let arr_global = ctx.module.add_global( - ty.array_type(len as u32), - Some(AddressSpace::default()), - &(id_str.clone() + "_"), - ); - let arr: BasicValueEnum = if ty.is_int_type() { - let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_int_value).collect(); - ty.into_int_type().const_array(&arr) - } else if ty.is_float_type() { - let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_float_value).collect(); - ty.into_float_type().const_array(&arr) - } else if ty.is_array_type() { - let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_array_value).collect(); - ty.into_array_type().const_array(&arr) - } else if ty.is_struct_type() { - let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_struct_value).collect(); - ty.into_struct_type().const_array(&arr) - } else if ty.is_pointer_type() { - let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_pointer_value).collect(); - ty.into_pointer_type().const_array(&arr) - } else { - unreachable!() - } - .into(); - arr_global.set_initializer(&arr); + let arr_global = ctx.module.add_global( + ty.array_type(len as u32), + Some(AddressSpace::default()), + &(id_str.clone() + "_"), + ); + let arr: BasicValueEnum = if ty.is_int_type() { + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_int_value).collect(); + ty.into_int_type().const_array(&arr) + } else if ty.is_float_type() { + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_float_value).collect(); + ty.into_float_type().const_array(&arr) + } else if ty.is_array_type() { + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_array_value).collect(); + ty.into_array_type().const_array(&arr) + } else if ty.is_struct_type() { + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_struct_value).collect(); + ty.into_struct_type().const_array(&arr) + } else if ty.is_pointer_type() { + let arr: Vec<_> = arr.into_iter().map(BasicValueEnum::into_pointer_value).collect(); + ty.into_pointer_type().const_array(&arr) + } else { + unreachable!() + } + .into(); + arr_global.set_initializer(&arr); - let val = arr_ty.const_named_struct(&[ - arr_global - .as_pointer_value() - .const_cast(ty.ptr_type(AddressSpace::default())) - .into(), - size_t.const_int(len as u64, false).into(), - ]); + let val = arr_ty.const_named_struct(&[ + arr_global.as_pointer_value().const_cast(ty.ptr_type(AddressSpace::default())).into(), + size_t.const_int(len as u64, false).into(), + ]); - let global = ctx.module.add_global(arr_ty, Some(AddressSpace::default()), &id_str); - global.set_initializer(&val); + let global = ctx.module.add_global(arr_ty, Some(AddressSpace::default()), &id_str); + global.set_initializer(&val); - Ok(Some(global.as_pointer_value().into())) - } else if ty_id == self.primitive_ids.ndarray { - let id_str = id.to_string(); + Ok(Some(global.as_pointer_value().into())) + } else if ty_id == self.primitive_ids.ndarray { + let id_str = id.to_string(); - if let Some(global) = ctx.module.get_global(&id_str) { + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + + let ndarray_ty = if matches!(&*ctx.unifier.get_ty_immutable(expected_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id()) + { + expected_ty + } else { + unreachable!("must be ndarray") + }; + let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); + + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty); + let dtype = llvm_ndarray.element_type(); + + { + if self.global_value_ids.read().contains_key(&id) { + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global( + llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + Some(AddressSpace::default()), + &id_str, + ) + }); return Ok(Some(global.as_pointer_value().into())); } + self.global_value_ids.write().insert(id, obj.into()); + } - let ndarray_ty = if matches!(&*ctx.unifier.get_ty_immutable(expected_ty), TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id()) - { - expected_ty - } else { - unreachable!("must be ndarray") - }; - let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ndarray_ty); + let ndims = llvm_ndarray.ndims(); - let llvm_i8 = ctx.ctx.i8_type(); - let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); - let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_ndarray = NDArrayType::from_unifier_type(generator, ctx, ndarray_ty); - let dtype = llvm_ndarray.element_type(); + // Obtain the shape of the ndarray + let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; + assert_eq!(shape_tuple.len(), ndims as usize); - { - if self.global_value_ids.read().contains_key(&id) { - let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { - ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), - Some(AddressSpace::default()), - &id_str, - ) - }); - return Ok(Some(global.as_pointer_value().into())); - } - self.global_value_ids.write().insert(id, obj.into()); - } + // The Rust type inferencer cannot figure this out + let shape_values = shape_tuple + .iter() + .enumerate() + .map(|(i, elem)| { + let value = self + .get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) + })? + .unwrap(); + let value = + ctx.builder.build_int_z_extend(value.into_int_value(), llvm_usize, "").unwrap(); + Ok(value) + }) + .collect::, PyErr>>()?; - let ndims = llvm_ndarray.ndims(); + // Also use this opportunity to get the constant values of `shape_values` for calculating strides. + let shape_u64s = shape_values + .iter() + .map(|dim| { + assert!(dim.is_const()); + dim.get_zero_extended_constant().unwrap() + }) + .collect_vec(); + let shape_values = llvm_usize.const_array(&shape_values); - // Obtain the shape of the ndarray - let shape_tuple: &PyTuple = obj.getattr("shape")?.downcast()?; - assert_eq!(shape_tuple.len(), ndims as usize); + // create a global for ndarray.shape and initialize it using the shape + let shape_global = ctx.module.add_global( + llvm_usize.array_type(ndims as u32), + Some(AddressSpace::default()), + &(id_str.clone() + ".shape"), + ); + shape_global.set_initializer(&shape_values); - // The Rust type inferencer cannot figure this out - let shape_values = shape_tuple - .iter() - .enumerate() - .map(|(i, elem)| { + // Obtain the (flattened) elements of the ndarray + let sz: usize = obj.getattr("size")?.extract()?; + let data: Vec<_> = (0..sz) + .map(|i| { + obj.getattr("flat")?.get_item(i).and_then(|elem| { let value = self - .get_obj_value(py, elem, ctx, generator, ctx.primitives.usize()) + .get_obj_value(py, elem, ctx, generator, ndarray_dtype) .map_err(|e| { super::CompileError::new_err(format!("Error getting element {i}: {e}")) })? .unwrap(); - let value = ctx - .builder - .build_int_z_extend(value.into_int_value(), llvm_usize, "") - .unwrap(); + + assert_eq!(value.get_type(), dtype); Ok(value) }) - .collect::, PyErr>>()?; + }) + .try_collect()?; + let data = data.into_iter(); + let data = match dtype { + BasicTypeEnum::ArrayType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec()) + } - // Also use this opportunity to get the constant values of `shape_values` for calculating strides. - let shape_u64s = shape_values - .iter() - .map(|dim| { - assert!(dim.is_const()); - dim.get_zero_extended_constant().unwrap() - }) - .collect_vec(); - let shape_values = llvm_usize.const_array(&shape_values); + BasicTypeEnum::FloatType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec()) + } - // create a global for ndarray.shape and initialize it using the shape - let shape_global = ctx.module.add_global( - llvm_usize.array_type(ndims as u32), - Some(AddressSpace::default()), - &(id_str.clone() + ".shape"), - ); - shape_global.set_initializer(&shape_values); + BasicTypeEnum::IntType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec()) + } - // Obtain the (flattened) elements of the ndarray - let sz: usize = obj.getattr("size")?.extract()?; - let data: Vec<_> = (0..sz) - .map(|i| { - obj.getattr("flat")?.get_item(i).and_then(|elem| { - let value = self - .get_obj_value(py, elem, ctx, generator, ndarray_dtype) - .map_err(|e| { - super::CompileError::new_err(format!( - "Error getting element {i}: {e}" - )) - })? - .unwrap(); + BasicTypeEnum::PointerType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec()) + } - assert_eq!(value.get_type(), dtype); - Ok(value) - }) - }) - .try_collect()?; - let data = data.into_iter(); - let data = match dtype { - BasicTypeEnum::ArrayType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_array_value).collect_vec()) - } + BasicTypeEnum::StructType(ty) => { + ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec()) + } - BasicTypeEnum::FloatType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_float_value).collect_vec()) - } + BasicTypeEnum::VectorType(_) => unreachable!(), + }; - BasicTypeEnum::IntType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_int_value).collect_vec()) - } + // create a global for ndarray.data and initialize it using the elements + // + // NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`. + // We will have to cast it to an `u8*` later. + let data_global = ctx.module.add_global( + dtype.array_type(sz as u32), + Some(AddressSpace::default()), + &(id_str.clone() + ".data"), + ); + data_global.set_initializer(&data); - BasicTypeEnum::PointerType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_pointer_value).collect_vec()) - } + // Get the constant itemsize. + // + // NOTE: dtype.size_of() may return a non-constant, where `TargetData::get_store_size` + // will always return a constant size. + let itemsize = ctx + .registry + .llvm_options + .create_target_machine() + .map(|tm| tm.get_target_data().get_store_size(&dtype)) + .unwrap(); + assert_ne!(itemsize, 0); - BasicTypeEnum::StructType(ty) => { - ty.const_array(&data.map(BasicValueEnum::into_struct_value).collect_vec()) - } + // Create the strides needed for ndarray.strides + let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s); + let strides = + strides.into_iter().map(|stride| llvm_usize.const_int(stride, false)).collect_vec(); + let strides = llvm_usize.const_array(&strides); - BasicTypeEnum::VectorType(_) => unreachable!(), - }; + // create a global for ndarray.strides and initialize it + let strides_global = ctx.module.add_global( + llvm_usize.array_type(ndims as u32), + Some(AddressSpace::default()), + &format!("${id_str}.strides"), + ); + strides_global.set_initializer(&strides); - // create a global for ndarray.data and initialize it using the elements - // - // NOTE: NDArray's `data` is `u8*`. Here, `data_global` is an array of `dtype`. - // We will have to cast it to an `u8*` later. - let data_global = ctx.module.add_global( - dtype.array_type(sz as u32), - Some(AddressSpace::default()), - &(id_str.clone() + ".data"), - ); - data_global.set_initializer(&data); + // create a global for the ndarray object and initialize it - // Get the constant itemsize. - // - // NOTE: dtype.size_of() may return a non-constant, where `TargetData::get_store_size` - // will always return a constant size. - let itemsize = ctx - .registry - .llvm_options - .create_target_machine() - .map(|tm| tm.get_target_data().get_store_size(&dtype)) - .unwrap(); - assert_ne!(itemsize, 0); + // NOTE: data_global is an array of dtype, we want a `u8*`. + let ndarray_data = data_global.as_pointer_value(); + let ndarray_data = ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap(); - // Create the strides needed for ndarray.strides - let strides = make_contiguous_strides(itemsize, ndims, &shape_u64s); - let strides = - strides.into_iter().map(|stride| llvm_usize.const_int(stride, false)).collect_vec(); - let strides = llvm_usize.const_array(&strides); + let ndarray_itemsize = llvm_usize.const_int(itemsize, false); - // create a global for ndarray.strides and initialize it - let strides_global = ctx.module.add_global( - llvm_usize.array_type(ndims as u32), - Some(AddressSpace::default()), - &format!("${id_str}.strides"), - ); - strides_global.set_initializer(&strides); + let ndarray_ndims = llvm_usize.const_int(ndims, false); - // create a global for the ndarray object and initialize it + // calling as_pointer_value on shape and strides returns [i64 x ndims]* + // convert into i64* to conform with expected layout of ndarray - // NOTE: data_global is an array of dtype, we want a `u8*`. - let ndarray_data = data_global.as_pointer_value(); - let ndarray_data = ctx.builder.build_pointer_cast(ndarray_data, llvm_pi8, "").unwrap(); + let ndarray_shape = shape_global.as_pointer_value(); + let ndarray_shape = unsafe { + ctx.builder + .build_in_bounds_gep( + ndarray_shape, + &[llvm_usize.const_zero(), llvm_usize.const_zero()], + "", + ) + .unwrap() + }; - let ndarray_itemsize = llvm_usize.const_int(itemsize, false); + let ndarray_strides = strides_global.as_pointer_value(); + let ndarray_strides = unsafe { + ctx.builder + .build_in_bounds_gep( + ndarray_strides, + &[llvm_usize.const_zero(), llvm_usize.const_zero()], + "", + ) + .unwrap() + }; - let ndarray_ndims = llvm_usize.const_int(ndims, false); - - // calling as_pointer_value on shape and strides returns [i64 x ndims]* - // convert into i64* to conform with expected layout of ndarray - - let ndarray_shape = shape_global.as_pointer_value(); - let ndarray_shape = unsafe { - ctx.builder - .build_in_bounds_gep( - ndarray_shape, - &[llvm_usize.const_zero(), llvm_usize.const_zero()], - "", - ) - .unwrap() - }; - - let ndarray_strides = strides_global.as_pointer_value(); - let ndarray_strides = unsafe { - ctx.builder - .build_in_bounds_gep( - ndarray_strides, - &[llvm_usize.const_zero(), llvm_usize.const_zero()], - "", - ) - .unwrap() - }; - - let ndarray = llvm_ndarray - .as_base_type() - .get_element_type() - .into_struct_type() - .const_named_struct(&[ + let ndarray = + llvm_ndarray.as_base_type().get_element_type().into_struct_type().const_named_struct( + &[ ndarray_itemsize.into(), ndarray_ndims.into(), ndarray_shape.into(), ndarray_strides.into(), ndarray_data.into(), - ]); - - let ndarray_global = ctx.module.add_global( - llvm_ndarray.as_base_type().get_element_type().into_struct_type(), - Some(AddressSpace::default()), - &id_str, + ], ); - ndarray_global.set_initializer(&ndarray); - Ok(Some(ndarray_global.as_pointer_value().into())) - } else if ty_id == self.primitive_ids.tuple { - let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); - let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { - unreachable!() - }; + let ndarray_global = ctx.module.add_global( + llvm_ndarray.as_base_type().get_element_type().into_struct_type(), + Some(AddressSpace::default()), + &id_str, + ); + ndarray_global.set_initializer(&ndarray); - let tup_tys = ty.iter(); - let elements: &PyTuple = obj.downcast()?; - assert_eq!(elements.len(), tup_tys.len()); - let val: Result>, _> = elements - .iter() - .enumerate() - .zip(tup_tys) - .map(|((i, elem), ty)| { - self.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| { - super::CompileError::new_err(format!("Error getting element {i}: {e}")) - }) + Ok(Some(ndarray_global.as_pointer_value().into())) + } else if ty_id == self.primitive_ids.tuple { + let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); + let TypeEnum::TTuple { ty, is_vararg_ctx: false } = expected_ty_enum.as_ref() else { + unreachable!() + }; + + let tup_tys = ty.iter(); + let elements: &PyTuple = obj.downcast()?; + assert_eq!(elements.len(), tup_tys.len()); + let val: Result>, _> = elements + .iter() + .enumerate() + .zip(tup_tys) + .map(|((i, elem), ty)| { + self.get_obj_value(py, elem, ctx, generator, *ty).map_err(|e| { + super::CompileError::new_err(format!("Error getting element {i}: {e}")) }) - .collect(); - let val = val?.unwrap(); - let val = ctx.ctx.const_struct(&val, false); - Ok(Some(val.into())) - } else if ty_id == self.primitive_ids.option { - let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() { - TypeEnum::TObj { obj_id, params, .. } - if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => - { - *params.iter().next().unwrap().1 - } - _ => unreachable!("must be option type"), - }; - if id == self.primitive_ids.none { - // for option type, just a null ptr - Ok(Some( - ctx.get_llvm_type(generator, option_val_ty) - .ptr_type(AddressSpace::default()) - .const_null() - .into(), - )) - } else { - match self - .get_obj_value( - py, - obj.getattr("_nac3_option").unwrap(), - ctx, - generator, - option_val_ty, - ) - .map_err(|e| { - super::CompileError::new_err(format!( - "Error getting value of Option object: {e}" - )) - })? { - Some(v) => { - let global_str = format!("{id}_option"); - { - if self.global_value_ids.read().contains_key(&id) { - let global = - ctx.module.get_global(&global_str).unwrap_or_else(|| { - ctx.module.add_global( - v.get_type(), - Some(AddressSpace::default()), - &global_str, - ) - }); - return Ok(Some(global.as_pointer_value().into())); - } - self.global_value_ids.write().insert(id, obj.into()); - } - let global = ctx.module.add_global( - v.get_type(), - Some(AddressSpace::default()), - &global_str, - ); - global.set_initializer(&v); - Ok(Some(global.as_pointer_value().into())) - } - None => Ok(None), - } - } - } else { - let id_str = id.to_string(); - - if let Some(global) = ctx.module.get_global(&id_str) { - return Ok(Some(global.as_pointer_value().into())); - } - - let top_level_defs = ctx.top_level.definitions.read(); - let ty = self - .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? - .unwrap(); - let ty = ctx - .get_llvm_type(generator, ty) - .into_pointer_type() - .get_element_type() - .into_struct_type(); + }) + .collect(); + let val = val?.unwrap(); + let val = ctx.ctx.const_struct(&val, false); + Ok(Some(val.into())) + } else if ty_id == self.primitive_ids.option { + let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == ctx.primitives.option.obj_id(&ctx.unifier).unwrap() => { - if self.global_value_ids.read().contains_key(&id) { - let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { - ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) - }); - return Ok(Some(global.as_pointer_value().into())); - } - self.global_value_ids.write().insert(id, obj.into()); + *params.iter().next().unwrap().1 } - // should be classes - let definition = - top_level_defs.get(self.pyid_to_def.read().get(&ty_id).unwrap().0).unwrap().read(); - let TopLevelDef::Class { fields, .. } = &*definition else { unreachable!() }; + _ => unreachable!("must be option type"), + }; + if id == self.primitive_ids.none { + // for option type, just a null ptr + Ok(Some( + ctx.get_llvm_type(generator, option_val_ty) + .ptr_type(AddressSpace::default()) + .const_null() + .into(), + )) + } else { + match self + .get_obj_value( + py, + obj.getattr("_nac3_option").unwrap(), + ctx, + generator, + option_val_ty, + ) + .map_err(|e| { + super::CompileError::new_err(format!( + "Error getting value of Option object: {e}" + )) + })? { + Some(v) => { + let global_str = format!("{id}_option"); + { + if self.global_value_ids.read().contains_key(&id) { + let global = ctx.module.get_global(&global_str).unwrap_or_else(|| { + ctx.module.add_global( + v.get_type(), + Some(AddressSpace::default()), + &global_str, + ) + }); + return Ok(Some(global.as_pointer_value().into())); + } + self.global_value_ids.write().insert(id, obj.into()); + } + let global = ctx.module.add_global( + v.get_type(), + Some(AddressSpace::default()), + &global_str, + ); + global.set_initializer(&v); + Ok(Some(global.as_pointer_value().into())) + } + None => Ok(None), + } + } + } else { + let id_str = id.to_string(); - let values: Result>, _> = fields - .iter() - .map(|(name, ty, _)| { - self.get_obj_value( - py, - obj.getattr(name.to_string().as_str())?, - ctx, - generator, - *ty, - ) - .map_err(|e| { - super::CompileError::new_err(format!("Error getting field {name}: {e}")) - }) - }) - .collect(); - let values = values?; - if let Some(values) = values { - let val = ty.const_named_struct(&values); + if let Some(global) = ctx.module.get_global(&id_str) { + return Ok(Some(global.as_pointer_value().into())); + } + + let top_level_defs = ctx.top_level.definitions.read(); + let ty = self + .get_obj_type(py, obj, &mut ctx.unifier, &top_level_defs, &ctx.primitives)? + .unwrap(); + let ty = ctx + .get_llvm_type(generator, ty) + .into_pointer_type() + .get_element_type() + .into_struct_type(); + { + if self.global_value_ids.read().contains_key(&id) { let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) }); - global.set_initializer(&val); - Ok(Some(global.as_pointer_value().into())) - } else { - Ok(None) + return Ok(Some(global.as_pointer_value().into())); } + self.global_value_ids.write().insert(id, obj.into()); + } + // should be classes + let definition = + top_level_defs.get(self.pyid_to_def.read().get(&ty_id).unwrap().0).unwrap().read(); + let TopLevelDef::Class { fields, .. } = &*definition else { unreachable!() }; + + let values: Result>, _> = fields + .iter() + .map(|(name, ty, _)| { + self.get_obj_value(py, obj.getattr(name.to_string().as_str())?, ctx, generator, *ty) + .map_err(|e| { + super::CompileError::new_err(format!("Error getting field {name}: {e}")) + }) + }) + .collect(); + let values = values?; + if let Some(values) = values { + let val = ty.const_named_struct(&values); + let global = ctx.module.get_global(&id_str).unwrap_or_else(|| { + ctx.module.add_global(ty, Some(AddressSpace::default()), &id_str) + }); + global.set_initializer(&val); + Ok(Some(global.as_pointer_value().into())) + } else { + Ok(None) } } +} - fn get_default_param_obj_value( - &self, - py: Python, - obj: &PyAny, - ) -> PyResult> { - let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; - let ty_id: u64 = - self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?; - Ok(if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { - let val: i32 = obj.extract()?; - Ok(SymbolValue::I32(val)) - } else if ty_id == self.primitive_ids.int64 { - let val: i64 = obj.extract()?; - Ok(SymbolValue::I64(val)) - } else if ty_id == self.primitive_ids.uint32 { - let val: u32 = obj.extract()?; - Ok(SymbolValue::U32(val)) - } else if ty_id == self.primitive_ids.uint64 { - let val: u64 = obj.extract()?; - Ok(SymbolValue::U64(val)) - } else if ty_id == self.primitive_ids.bool { - let val: bool = obj.extract()?; - Ok(SymbolValue::Bool(val)) - } else if ty_id == self.primitive_ids.np_bool_ { - let val: bool = obj.call_method("__bool__", (), None)?.extract()?; - Ok(SymbolValue::Bool(val)) - } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { - let val: String = obj.extract()?; - Ok(SymbolValue::Str(val)) - } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { - let val: f64 = obj.extract()?; - Ok(SymbolValue::Double(val)) - } else if ty_id == self.primitive_ids.tuple { - let elements: &PyTuple = obj.downcast()?; - let elements: Result, String>, _> = - elements.iter().map(|elem| self.get_default_param_obj_value(py, elem)).collect(); - elements?.map(SymbolValue::Tuple) - } else if ty_id == self.primitive_ids.option { - if id == self.primitive_ids.none { - Ok(SymbolValue::OptionNone) - } else { - self.get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())? - .map(|v| SymbolValue::OptionSome(Box::new(v))) - } +fn get_default_param_obj_value( + &self, + py: Python, + obj: &PyAny, +) -> PyResult> { + let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; + let ty_id: u64 = + self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?; + Ok(if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { + let val: i32 = obj.extract()?; + Ok(SymbolValue::I32(val)) + } else if ty_id == self.primitive_ids.int64 { + let val: i64 = obj.extract()?; + Ok(SymbolValue::I64(val)) + } else if ty_id == self.primitive_ids.uint32 { + let val: u32 = obj.extract()?; + Ok(SymbolValue::U32(val)) + } else if ty_id == self.primitive_ids.uint64 { + let val: u64 = obj.extract()?; + Ok(SymbolValue::U64(val)) + } else if ty_id == self.primitive_ids.bool { + let val: bool = obj.extract()?; + Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.np_bool_ { + let val: bool = obj.call_method("__bool__", (), None)?.extract()?; + Ok(SymbolValue::Bool(val)) + } else if ty_id == self.primitive_ids.string || ty_id == self.primitive_ids.np_str_ { + let val: String = obj.extract()?; + Ok(SymbolValue::Str(val)) + } else if ty_id == self.primitive_ids.float || ty_id == self.primitive_ids.float64 { + let val: f64 = obj.extract()?; + Ok(SymbolValue::Double(val)) + } else if ty_id == self.primitive_ids.tuple { + let elements: &PyTuple = obj.downcast()?; + let elements: Result, String>, _> = + elements.iter().map(|elem| self.get_default_param_obj_value(py, elem)).collect(); + elements?.map(SymbolValue::Tuple) + } else if ty_id == self.primitive_ids.option { + if id == self.primitive_ids.none { + Ok(SymbolValue::OptionNone) } else { - Err("only primitives values, option and tuple can be default parameter value".into()) - }) - } + self.get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())? + .map(|v| SymbolValue::OptionSome(Box::new(v))) + } + } else { + Err("only primitives values, option and tuple can be default parameter value".into()) + }) +} impl SymbolResolver for Resolver { fn get_default_param_value(&self, expr: &ast::Expr) -> Option {