From 9a98cde595b26c360f6ce8f2ac4cec913f8e15fc Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 11 Mar 2024 15:09:33 +0800 Subject: [PATCH] core: Extract codegen portion of gen_*op_expr This allows *ops to be generated internally using LLVM values as input. Required in a future change. --- nac3core/src/codegen/expr.rs | 403 ++++++++++++++++++++++------------- 1 file changed, 256 insertions(+), 147 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index ea8211a8c..56e6e0b90 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1090,34 +1090,22 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>( Ok(Some(list.as_ptr_value().into())) } -/// Generates LLVM IR for a [binary operator expression][expr]. -/// -/// * `left` - The left-hand side of the binary operator. -/// * `op` - The operator applied on the operands. -/// * `right` - The right-hand side of the binary operator. -/// * `loc` - The location of the full expression. -/// * `is_aug_assign` - Whether the binary operator expression is also an assignment operator. -pub fn gen_binop_expr<'ctx, G: CodeGenerator>( +/// Generates LLVM IR for a binary operator expression using the [`Type`] and +/// [LLVM value][`BasicValueEnum`] of the operands. +pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( generator: &mut G, ctx: &mut CodeGenContext<'ctx, '_>, - left: &Expr>, + left: (&Option, BasicValueEnum<'ctx>), op: &Operator, - right: &Expr>, + right: (&Option, BasicValueEnum<'ctx>), loc: Location, is_aug_assign: bool, ) -> Result>, String> { - let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); - let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); - let left_val = if let Some(v) = generator.gen_expr(ctx, left)? { - v.to_basic_value_enum(ctx, generator, left.custom.unwrap())? - } else { - return Ok(None) - }; - let right_val = if let Some(v) = generator.gen_expr(ctx, right)? { - v.to_basic_value_enum(ctx, generator, right.custom.unwrap())? - } else { - return Ok(None) - }; + let (left_ty, left_val) = left; + let (right_ty, right_val) = right; + + let ty1 = ctx.unifier.get_representative(left_ty.unwrap()); + let ty2 = ctx.unifier.get_representative(right_ty.unwrap()); // we can directly compare the types, because we've got their representatives // which would be unchanged until further unification, which we would never do @@ -1142,7 +1130,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( ); Ok(Some(res.into())) } else { - let left_ty_enum = ctx.unifier.get_ty_immutable(left.custom.unwrap()); + let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else { unreachable!("must be tobj") }; @@ -1162,7 +1150,7 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( let signature = if let Some(call) = ctx.calls.get(&loc.into()) { ctx.unifier.get_call_signature(*call).unwrap() } else { - let left_enum_ty = ctx.unifier.get_ty_immutable(left.custom.unwrap()); + let left_enum_ty = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else { unreachable!("must be tobj") }; @@ -1187,13 +1175,254 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>( generator .gen_call( ctx, - Some((left.custom.unwrap(), left_val.into())), + Some((left_ty.unwrap(), left_val.into())), (&signature, fun_id), vec![(None, right_val.into())], ).map(|f| f.map(Into::into)) } } +/// Generates LLVM IR for a binary operator expression. +/// +/// * `left` - The left-hand side of the binary operator. +/// * `op` - The operator applied on the operands. +/// * `right` - The right-hand side of the binary operator. +/// * `loc` - The location of the full expression. +/// * `is_aug_assign` - Whether the binary operator expression is also an assignment operator. +pub fn gen_binop_expr<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + left: &Expr>, + op: &Operator, + right: &Expr>, + loc: Location, + is_aug_assign: bool, +) -> Result>, String> { + let left_val = if let Some(v) = generator.gen_expr(ctx, left)? { + v.to_basic_value_enum(ctx, generator, left.custom.unwrap())? + } else { + return Ok(None) + }; + let right_val = if let Some(v) = generator.gen_expr(ctx, right)? { + v.to_basic_value_enum(ctx, generator, right.custom.unwrap())? + } else { + return Ok(None) + }; + + gen_binop_expr_with_values( + generator, + ctx, + (&left.custom, left_val), + op, + (&right.custom, right_val), + loc, + is_aug_assign, + ) +} + +/// Generates LLVM IR for a unary operator expression using the [`Type`] and +/// [LLVM value][`BasicValueEnum`] of the operands. +pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>( + _generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + op: &ast::Unaryop, + operand: (&Option, BasicValueEnum<'ctx>), +) -> Result>, String> { + let (ty, val) = operand; + let ty = ctx.unifier.get_representative(ty.unwrap()); + + Ok(Some(if ty == ctx.primitives.bool { + let val = val.into_int_value(); + match op { + ast::Unaryop::Invert | ast::Unaryop::Not => { + ctx.builder.build_not(val, "not").map(Into::into).unwrap() + } + _ => val.into(), + } + } else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) { + let val = val.into_int_value(); + match op { + ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").map(Into::into).unwrap(), + ast::Unaryop::Invert => ctx.builder.build_not(val, "not").map(Into::into).unwrap(), + ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").map(Into::into).unwrap(), + ast::Unaryop::UAdd => val.into(), + } + } else if ty == ctx.primitives.float { + let val = val.into_float_value(); + match op { + ast::Unaryop::USub => ctx.builder.build_float_neg(val, "neg").map(Into::into).unwrap(), + ast::Unaryop::Not => ctx + .builder + .build_float_compare( + inkwell::FloatPredicate::OEQ, + val, + val.get_type().const_zero(), + "not", + ) + .map(Into::into) + .unwrap(), + _ => val.into(), + } + } else { + unimplemented!() + })) +} + +/// Generates LLVM IR for a unary operator expression. +/// +/// * `op` - The operator applied on the operand. +/// * `operand` - The unary operand. +pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + op: &ast::Unaryop, + operand: &Expr>, +) -> Result>, String> { + let val = if let Some(v) = generator.gen_expr(ctx, operand)? { + v.to_basic_value_enum(ctx, generator, operand.custom.unwrap())? + } else { + return Ok(None) + }; + + gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val)) +} + +/// Generates LLVM IR for a comparison operator expression using the [`Type`] and +/// [LLVM value][`BasicValueEnum`] of the operands. +pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>( + _generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + left: (Option, BasicValueEnum<'ctx>), + ops: &[ast::Cmpop], + comparators: &[(Option, BasicValueEnum<'ctx>)], +) -> Result>, String> { + let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),) + .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { + let (left_ty, lhs) = lhs; + let (right_ty, rhs) = rhs; + + let left_ty = ctx.unifier.get_representative(left_ty.unwrap()); + let right_ty = ctx.unifier.get_representative(right_ty.unwrap()); + + let current = + if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64, ctx.primitives.bool] + .contains(&left_ty) + { + assert!(ctx.unifier.unioned(left_ty, right_ty)); + + let use_unsigned_ops = [ + ctx.primitives.uint32, + ctx.primitives.uint64, + ].contains(&left_ty); + + let lhs = lhs.into_int_value(); + let rhs = rhs.into_int_value(); + + let op = match op { + ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ, + ast::Cmpop::NotEq => IntPredicate::NE, + _ if left_ty == ctx.primitives.bool => unreachable!(), + ast::Cmpop::Lt => if use_unsigned_ops { + IntPredicate::ULT + } else { + IntPredicate::SLT + }, + ast::Cmpop::LtE => if use_unsigned_ops { + IntPredicate::ULE + } else { + IntPredicate::SLE + }, + ast::Cmpop::Gt => if use_unsigned_ops { + IntPredicate::UGT + } else { + IntPredicate::SGT + }, + ast::Cmpop::GtE => if use_unsigned_ops { + IntPredicate::UGE + } else { + IntPredicate::SGE + }, + _ => unreachable!(), + }; + + ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap() + } else if left_ty == ctx.primitives.float { + assert!(ctx.unifier.unioned(left_ty, right_ty)); + + let lhs = lhs.into_float_value(); + let rhs = rhs.into_float_value(); + + let op = match op { + ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ, + ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE, + ast::Cmpop::Lt => inkwell::FloatPredicate::OLT, + ast::Cmpop::LtE => inkwell::FloatPredicate::OLE, + ast::Cmpop::Gt => inkwell::FloatPredicate::OGT, + ast::Cmpop::GtE => inkwell::FloatPredicate::OGE, + _ => unreachable!(), + }; + ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap() + } else { + unimplemented!() + }; + Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) + })?; + + Ok(Some(match cmp_val { + Some(v) => v.into(), + None => return Ok(None), + })) +} + +/// Generates LLVM IR for a comparison operator expression. +/// +/// * `left` - The left-hand side of the comparison operator. +/// * `ops` - The (possibly chained) operators applied on the operands. +/// * `comparators` - The right-hand side of the binary operator. +pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + left: &Expr>, + ops: &[ast::Cmpop], + comparators: &[Expr>], +) -> Result>, String> { + let left_val = if let Some(v) = generator.gen_expr(ctx, left)? { + v.to_basic_value_enum(ctx, generator, left.custom.unwrap())? + } else { + return Ok(None) + }; + let comparator_vals = comparators.iter() + .map(|cmptor| { + Ok(if let Some(v) = generator.gen_expr(ctx, cmptor)? { + Some((cmptor.custom, v.to_basic_value_enum(ctx, generator, cmptor.custom.unwrap())?)) + } else { + None + }) + }) + .take_while(|v| if let Ok(v) = v { + v.is_some() + } else { + true + }) + .collect::, String>>()?; + let comparator_vals = if comparator_vals.len() == comparators.len() { + comparator_vals + .into_iter() + .map(Option::unwrap) + .collect_vec() + } else { + return Ok(None) + }; + + gen_cmpop_expr_with_values( + generator, + ctx, + (left.custom, left_val), + ops, + comparator_vals.as_slice(), + ) +} + /// Generates code for a subscript expression on an `ndarray`. /// /// * `ty` - The `Type` of the `NDArray` elements. @@ -1570,130 +1799,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>( return gen_binop_expr(generator, ctx, left, op, right, expr.location, false); } ExprKind::UnaryOp { op, operand } => { - let ty = ctx.unifier.get_representative(operand.custom.unwrap()); - let val = if let Some(v) = generator.gen_expr(ctx, operand)? { - v.to_basic_value_enum(ctx, generator, operand.custom.unwrap())? - } else { - return Ok(None) - }; - if ty == ctx.primitives.bool { - let val = val.into_int_value(); - match op { - ast::Unaryop::Invert | ast::Unaryop::Not => { - ctx.builder.build_not(val, "not").map(Into::into).unwrap() - } - _ => val.into(), - } - } else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) { - let val = val.into_int_value(); - match op { - ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").map(Into::into).unwrap(), - ast::Unaryop::Invert => ctx.builder.build_not(val, "not").map(Into::into).unwrap(), - ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").map(Into::into).unwrap(), - ast::Unaryop::UAdd => val.into(), - } - } else if ty == ctx.primitives.float { - let val = val.into_float_value(); - match op { - ast::Unaryop::USub => ctx.builder.build_float_neg(val, "neg").map(Into::into).unwrap(), - ast::Unaryop::Not => ctx - .builder - .build_float_compare( - inkwell::FloatPredicate::OEQ, - val, - val.get_type().const_zero(), - "not", - ) - .map(Into::into) - .unwrap(), - _ => val.into(), - } - } else { - unimplemented!() - } + return gen_unaryop_expr(generator, ctx, op, operand) } ExprKind::Compare { left, ops, comparators } => { - let cmp_val = izip!(chain(once(left.as_ref()), comparators.iter()), comparators.iter(), ops.iter(),) - .fold(Ok(None), |prev: Result, String>, (lhs, rhs, op)| { - let ty = ctx.unifier.get_representative(lhs.custom.unwrap()); - let current = - if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64, ctx.primitives.bool] - .contains(&ty) - { - let use_unsigned_ops = [ - ctx.primitives.uint32, - ctx.primitives.uint64, - ].contains(&ty); - - let BasicValueEnum::IntValue(lhs) = (match generator.gen_expr(ctx, lhs)? { - Some(v) => v.to_basic_value_enum(ctx, generator, lhs.custom.unwrap())?, - None => return Ok(None), - }) else { unreachable!() }; - - let BasicValueEnum::IntValue(rhs) = (match generator.gen_expr(ctx, rhs)? { - Some(v) => v.to_basic_value_enum(ctx, generator, rhs.custom.unwrap())?, - None => return Ok(None), - }) else { unreachable!() }; - - let op = match op { - ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ, - ast::Cmpop::NotEq => IntPredicate::NE, - _ if ty == ctx.primitives.bool => unreachable!(), - ast::Cmpop::Lt => if use_unsigned_ops { - IntPredicate::ULT - } else { - IntPredicate::SLT - }, - ast::Cmpop::LtE => if use_unsigned_ops { - IntPredicate::ULE - } else { - IntPredicate::SLE - }, - ast::Cmpop::Gt => if use_unsigned_ops { - IntPredicate::UGT - } else { - IntPredicate::SGT - }, - ast::Cmpop::GtE => if use_unsigned_ops { - IntPredicate::UGE - } else { - IntPredicate::SGE - }, - _ => unreachable!(), - }; - - ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap() - } else if ty == ctx.primitives.float { - let BasicValueEnum::FloatValue(lhs) = (match generator.gen_expr(ctx, lhs)? { - Some(v) => v.to_basic_value_enum(ctx, generator, lhs.custom.unwrap())?, - None => return Ok(None), - }) else { unreachable!() }; - - let BasicValueEnum::FloatValue(rhs) = (match generator.gen_expr(ctx, rhs)? { - Some(v) => v.to_basic_value_enum(ctx, generator, rhs.custom.unwrap())?, - None => return Ok(None), - }) else { unreachable!() }; - - let op = match op { - ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ, - ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE, - ast::Cmpop::Lt => inkwell::FloatPredicate::OLT, - ast::Cmpop::LtE => inkwell::FloatPredicate::OLE, - ast::Cmpop::Gt => inkwell::FloatPredicate::OGT, - ast::Cmpop::GtE => inkwell::FloatPredicate::OGE, - _ => unreachable!(), - }; - ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap() - } else { - unimplemented!() - }; - Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current))) - })?; - - match cmp_val { - Some(v) => v.into(), - None => return Ok(None), - } + return gen_cmpop_expr(generator, ctx, left, ops, comparators) } ExprKind::IfExp { test, body, orelse } => { let test = match generator.gen_expr(ctx, test)? {