From 9b988647edee6d252e5d87d721d45b7c652d600e Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 16 Jul 2024 19:01:38 +0800 Subject: [PATCH] core/toplevel/builtins: Extract len() into builtin function --- nac3core/src/codegen/builtin_fns.rs | 78 +++++++++++++++++++++++++-- nac3core/src/toplevel/builtins.rs | 82 +---------------------------- 2 files changed, 77 insertions(+), 83 deletions(-) diff --git a/nac3core/src/codegen/builtin_fns.rs b/nac3core/src/codegen/builtin_fns.rs index 311fd35..640eb94 100644 --- a/nac3core/src/codegen/builtin_fns.rs +++ b/nac3core/src/codegen/builtin_fns.rs @@ -1,17 +1,20 @@ use inkwell::types::BasicTypeEnum; -use inkwell::values::{BasicValue, BasicValueEnum, PointerValue}; +use inkwell::values::{BasicValue, BasicValueEnum, IntValue, PointerValue}; use inkwell::{FloatPredicate, IntPredicate, OptimizationLevel}; use itertools::Itertools; use crate::codegen::classes::{ - NDArrayValue, ProxyValue, UntypedArrayLikeAccessor, UntypedArrayLikeMutator, + ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor, + UntypedArrayLikeAccessor, UntypedArrayLikeMutator, }; +use crate::codegen::expr::destructure_range; +use crate::codegen::irrt::calculate_len_for_slice_range; use crate::codegen::numpy::ndarray_elementwise_unaryop_impl; use crate::codegen::stmt::gen_for_callback_incrementing; use crate::codegen::{extern_fns, irrt, llvm_intrinsics, numpy, CodeGenContext, CodeGenerator}; use crate::toplevel::helper::PrimDef; use crate::toplevel::numpy::unpack_ndarray_var_tys; -use crate::typecheck::typedef::Type; +use crate::typecheck::typedef::{Type, TypeEnum}; /// Shorthand for [`unreachable!()`] when a type of argument is not supported. /// @@ -23,6 +26,75 @@ fn unsupported_type(ctx: &CodeGenContext<'_, '_>, fn_name: &str, tys: &[Type]) - ) } +/// Invokes the `len` builtin function. +pub fn call_len<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + n: (Type, BasicValueEnum<'ctx>), +) -> Result, String> { + let range_ty = ctx.primitives.range; + let (arg_ty, arg) = n; + + Ok(if ctx.unifier.unioned(arg_ty, range_ty) { + let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); + let (start, end, step) = destructure_range(ctx, arg); + calculate_len_for_slice_range(generator, ctx, start, end, step) + } else { + match &*ctx.unifier.get_ty_immutable(arg_ty) { + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + let len = ctx + .build_gep_and_load( + arg.into_pointer_value(), + &[zero, int32.const_int(1, false)], + None, + ) + .into_int_value(); + if len.get_type().get_bit_width() == 32 { + len + } else { + ctx.builder.build_int_truncate(len, int32, "len2i32").unwrap() + } + } + TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let arg = NDArrayValue::from_ptr_val(arg.into_pointer_value(), llvm_usize, None); + + let ndims = arg.dim_sizes().size(ctx, generator); + ctx.make_assert( + generator, + ctx.builder + .build_int_compare(IntPredicate::NE, ndims, llvm_usize.const_zero(), "") + .unwrap(), + "0:TypeError", + "len() of unsized object", + [None, None, None], + ctx.current_loc, + ); + + let len = unsafe { + arg.dim_sizes().get_typed_unchecked( + ctx, + generator, + &llvm_usize.const_zero(), + None, + ) + }; + + if len.get_type().get_bit_width() == 32 { + len + } else { + ctx.builder.build_int_truncate(len, llvm_i32, "len").unwrap() + } + } + _ => unreachable!(), + } + }) +} + /// Invokes the `int32` builtin function. pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>( generator: &mut G, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index be8687e..b59bed8 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -14,9 +14,7 @@ use strum::IntoEnumIterator; use crate::{ codegen::{ builtin_fns, - classes::{ArrayLikeValue, NDArrayValue, ProxyValue, RangeValue, TypedArrayLikeAccessor}, - expr::destructure_range, - irrt::*, + classes::{ProxyValue, RangeValue}, numpy::*, stmt::exn_constructor, }, @@ -1503,86 +1501,10 @@ impl<'a> BuiltinBuilder<'a> { resolver: None, codegen_callback: Some(Arc::new(GenCall::new(Box::new( move |ctx, _, fun, args, generator| { - let range_ty = ctx.primitives.range; let arg_ty = fun.0.args[0].ty; let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?; - Ok(if ctx.unifier.unioned(arg_ty, range_ty) { - let arg = RangeValue::from_ptr_val(arg.into_pointer_value(), Some("range")); - let (start, end, step) = destructure_range(ctx, arg); - Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) - } else { - match &*ctx.unifier.get_ty_immutable(arg_ty) { - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::List.id() => { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, int32.const_int(1, false)], - None, - ) - .into_int_value(); - if len.get_type().get_bit_width() == 32 { - Some(len.into()) - } else { - Some( - ctx.builder - .build_int_truncate(len, int32, "len2i32") - .map(Into::into) - .unwrap(), - ) - } - } - TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => { - let llvm_i32 = ctx.ctx.i32_type(); - let llvm_usize = generator.get_size_type(ctx.ctx); - let arg = NDArrayValue::from_ptr_val( - arg.into_pointer_value(), - llvm_usize, - None, - ); - - let ndims = arg.dim_sizes().size(ctx, generator); - ctx.make_assert( - generator, - ctx.builder - .build_int_compare( - IntPredicate::NE, - ndims, - llvm_usize.const_zero(), - "", - ) - .unwrap(), - "0:TypeError", - &format!("{name}() of unsized object", name = prim.name()), - [None, None, None], - ctx.current_loc, - ); - - let len = unsafe { - arg.dim_sizes().get_typed_unchecked( - ctx, - generator, - &llvm_usize.const_zero(), - None, - ) - }; - - if len.get_type().get_bit_width() == 32 { - Some(len.into()) - } else { - Some( - ctx.builder - .build_int_truncate(len, llvm_i32, "len") - .map(Into::into) - .unwrap(), - ) - } - } - _ => unreachable!(), - } - }) + builtin_fns::call_len(generator, ctx, (arg_ty, arg)).map(|ret| Some(ret.into())) }, )))), loc: None,