Compare commits
6 Commits
29ae48faad
...
1cb9a90825
Author | SHA1 | Date |
---|---|---|
David Mak | 1cb9a90825 | |
David Mak | 42c482f897 | |
David Mak | 1a09ea126d | |
David Mak | f0c8f88ce3 | |
David Mak | 0c3e353a11 | |
David Mak | 2f73c96e98 |
|
@ -39,7 +39,6 @@ use inkwell::{
|
|||
types::{AnyType, BasicType, BasicTypeEnum},
|
||||
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}
|
||||
};
|
||||
use inkwell::values::BasicValue;
|
||||
use itertools::{chain, izip, Itertools, Either};
|
||||
use nac3parser::ast::{
|
||||
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
|
||||
|
@ -1154,7 +1153,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
ndarray_dtype1,
|
||||
if is_aug_assign { Some(left_val) } else { None },
|
||||
(left_val.as_ptr_value().into(), false),
|
||||
(right_val, false),
|
||||
(right_val.into(), false),
|
||||
|generator, ctx, elem_ty, (lhs, rhs)| {
|
||||
gen_binop_expr_with_values(
|
||||
generator,
|
||||
|
@ -1292,294 +1291,6 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
|
|||
)
|
||||
}
|
||||
|
||||
pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
op: &ast::Unaryop,
|
||||
operand: (&Option<Type>, BasicValueEnum<'ctx>),
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
let (ty, val) = operand;
|
||||
let ty = ctx.unifier.get_representative(ty.unwrap());
|
||||
|
||||
Ok(Some(if ty == ctx.primitives.bool {
|
||||
let val = val.into_int_value();
|
||||
match op {
|
||||
ast::Unaryop::Invert | ast::Unaryop::Not => {
|
||||
ctx.builder.build_not(val, "not").map(Into::into).unwrap()
|
||||
}
|
||||
_ => val.into(),
|
||||
}
|
||||
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
|
||||
let val = val.into_int_value();
|
||||
match op {
|
||||
ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").map(Into::into).unwrap(),
|
||||
ast::Unaryop::Invert => ctx.builder.build_not(val, "not").map(Into::into).unwrap(),
|
||||
ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").map(Into::into).unwrap(),
|
||||
ast::Unaryop::UAdd => val.into(),
|
||||
}
|
||||
} else if ty == ctx.primitives.float {
|
||||
let val = val.into_float_value();
|
||||
match op {
|
||||
ast::Unaryop::USub => ctx.builder.build_float_neg(val, "neg").map(Into::into).unwrap(),
|
||||
ast::Unaryop::Not => ctx
|
||||
.builder
|
||||
.build_float_compare(
|
||||
inkwell::FloatPredicate::OEQ,
|
||||
val,
|
||||
val.get_type().const_zero(),
|
||||
"not",
|
||||
)
|
||||
.map(Into::into)
|
||||
.unwrap(),
|
||||
_ => val.into(),
|
||||
}
|
||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||
|
||||
let val = NDArrayValue::from_ptr_val(
|
||||
val.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
||||
let res = numpy::ndarray_elementwise_unaryop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype,
|
||||
None,
|
||||
val,
|
||||
|generator, ctx, elem_ty, val| {
|
||||
gen_unaryop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
op,
|
||||
(&Some(elem_ty), val)
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
|
||||
},
|
||||
)?;
|
||||
|
||||
res.as_ptr_value().into()
|
||||
} else {
|
||||
unimplemented!()
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
op: &ast::Unaryop,
|
||||
operand: &Expr<Option<Type>>,
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
let val = if let Some(v) = generator.gen_expr(ctx, operand)? {
|
||||
v.to_basic_value_enum(ctx, generator, operand.custom.unwrap())?
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
|
||||
gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val))
|
||||
}
|
||||
|
||||
pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
left: (Option<Type>, BasicValueEnum<'ctx>),
|
||||
ops: &[ast::Cmpop],
|
||||
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
debug_assert_eq!(comparators.len(), ops.len());
|
||||
|
||||
if comparators.len() == 1 {
|
||||
let left_ty = ctx.unifier.get_representative(left.0.unwrap());
|
||||
let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap());
|
||||
|
||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (Some(left_ty), lhs) = left else { unreachable!() };
|
||||
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
|
||||
let op = ops[0].clone();
|
||||
|
||||
let is_ndarray1 = left_ty.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
|
||||
let is_ndarray2 = right_ty.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
|
||||
|
||||
return if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
||||
|
||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
let left_val = NDArrayValue::from_ptr_val(
|
||||
lhs.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
None,
|
||||
(left_val.as_ptr_value().into(), false),
|
||||
(rhs, false),
|
||||
|generator, ctx, elem_ty, (lhs, rhs)| {
|
||||
let val = gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(Some(ndarray_dtype1), lhs),
|
||||
&[op.clone()],
|
||||
&[(Some(ndarray_dtype2), rhs)],
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||
|
||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_ptr_value().into()))
|
||||
} else {
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
||||
&mut ctx.unifier,
|
||||
if is_ndarray1 { left_ty } else { right_ty },
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ctx.primitives.bool,
|
||||
None,
|
||||
(lhs, !is_ndarray1),
|
||||
(rhs, !is_ndarray2),
|
||||
|generator, ctx, elem_ty, (lhs, rhs)| {
|
||||
let val = gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(Some(ndarray_dtype), lhs),
|
||||
&[op.clone()],
|
||||
&[(Some(ndarray_dtype), rhs)],
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
|
||||
|
||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_ptr_value().into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
|
||||
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
|
||||
let (left_ty, lhs) = lhs;
|
||||
let (right_ty, rhs) = rhs;
|
||||
|
||||
let left_ty = ctx.unifier.get_representative(left_ty.unwrap());
|
||||
let right_ty = ctx.unifier.get_representative(right_ty.unwrap());
|
||||
|
||||
let current =
|
||||
if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64, ctx.primitives.bool]
|
||||
.contains(&left_ty)
|
||||
{
|
||||
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
||||
|
||||
let use_unsigned_ops = [
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.uint64,
|
||||
].contains(&left_ty);
|
||||
|
||||
let lhs = lhs.into_int_value();
|
||||
let rhs = rhs.into_int_value();
|
||||
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ,
|
||||
ast::Cmpop::NotEq => IntPredicate::NE,
|
||||
_ if left_ty == ctx.primitives.bool => unreachable!(),
|
||||
ast::Cmpop::Lt => if use_unsigned_ops {
|
||||
IntPredicate::ULT
|
||||
} else {
|
||||
IntPredicate::SLT
|
||||
},
|
||||
ast::Cmpop::LtE => if use_unsigned_ops {
|
||||
IntPredicate::ULE
|
||||
} else {
|
||||
IntPredicate::SLE
|
||||
},
|
||||
ast::Cmpop::Gt => if use_unsigned_ops {
|
||||
IntPredicate::UGT
|
||||
} else {
|
||||
IntPredicate::SGT
|
||||
},
|
||||
ast::Cmpop::GtE => if use_unsigned_ops {
|
||||
IntPredicate::UGE
|
||||
} else {
|
||||
IntPredicate::SGE
|
||||
},
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap()
|
||||
} else if left_ty == ctx.primitives.float {
|
||||
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
||||
|
||||
let lhs = lhs.into_float_value();
|
||||
let rhs = rhs.into_float_value();
|
||||
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
|
||||
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
|
||||
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
|
||||
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
|
||||
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
|
||||
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap()
|
||||
} else {
|
||||
unimplemented!()
|
||||
};
|
||||
|
||||
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
||||
})?;
|
||||
|
||||
Ok(Some(match cmp_val {
|
||||
Some(v) => v.into(),
|
||||
None => return Ok(None),
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
left: &Expr<Option<Type>>,
|
||||
ops: &[ast::Cmpop],
|
||||
comparators: &[Expr<Option<Type>>],
|
||||
) -> Result<Option<ValueEnum<'ctx>>, String> {
|
||||
let left_val = if let Some(v) = generator.gen_expr(ctx, left)? {
|
||||
v.to_basic_value_enum(ctx, generator, left.custom.unwrap())?
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
let comparators = {
|
||||
let mut new_comparators = Vec::new();
|
||||
new_comparators.reserve(comparators.len());
|
||||
|
||||
for cmptor in comparators {
|
||||
if let Some(v) = generator.gen_expr(ctx, cmptor)? {
|
||||
new_comparators.push((cmptor.custom, v.to_basic_value_enum(ctx, generator, cmptor.custom.unwrap())?))
|
||||
} else {
|
||||
return Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
new_comparators
|
||||
};
|
||||
|
||||
gen_cmpop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(left.custom, left_val),
|
||||
ops,
|
||||
comparators.as_slice(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates code for a subscript expression on an `ndarray`.
|
||||
///
|
||||
/// * `ty` - The `Type` of the `NDArray` elements.
|
||||
|
@ -1956,10 +1667,130 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
|
|||
return gen_binop_expr(generator, ctx, left, op, right, expr.location, false);
|
||||
}
|
||||
ExprKind::UnaryOp { op, operand } => {
|
||||
return gen_unaryop_expr(generator, ctx, op, operand)
|
||||
let ty = ctx.unifier.get_representative(operand.custom.unwrap());
|
||||
let val = if let Some(v) = generator.gen_expr(ctx, operand)? {
|
||||
v.to_basic_value_enum(ctx, generator, operand.custom.unwrap())?
|
||||
} else {
|
||||
return Ok(None)
|
||||
};
|
||||
if ty == ctx.primitives.bool {
|
||||
let val = val.into_int_value();
|
||||
match op {
|
||||
ast::Unaryop::Invert | ast::Unaryop::Not => {
|
||||
ctx.builder.build_not(val, "not").map(Into::into).unwrap()
|
||||
}
|
||||
_ => val.into(),
|
||||
}
|
||||
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
|
||||
let val = val.into_int_value();
|
||||
match op {
|
||||
ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").map(Into::into).unwrap(),
|
||||
ast::Unaryop::Invert => ctx.builder.build_not(val, "not").map(Into::into).unwrap(),
|
||||
ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").map(Into::into).unwrap(),
|
||||
ast::Unaryop::UAdd => val.into(),
|
||||
}
|
||||
} else if ty == ctx.primitives.float {
|
||||
let val = val.into_float_value();
|
||||
match op {
|
||||
ast::Unaryop::USub => ctx.builder.build_float_neg(val, "neg").map(Into::into).unwrap(),
|
||||
ast::Unaryop::Not => ctx
|
||||
.builder
|
||||
.build_float_compare(
|
||||
inkwell::FloatPredicate::OEQ,
|
||||
val,
|
||||
val.get_type().const_zero(),
|
||||
"not",
|
||||
)
|
||||
.map(Into::into)
|
||||
.unwrap(),
|
||||
_ => val.into(),
|
||||
}
|
||||
} else {
|
||||
unimplemented!()
|
||||
}
|
||||
}
|
||||
ExprKind::Compare { left, ops, comparators } => {
|
||||
return gen_cmpop_expr(generator, ctx, left, &ops, &comparators)
|
||||
let cmp_val = izip!(chain(once(left.as_ref()), comparators.iter()), comparators.iter(), ops.iter(),)
|
||||
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
|
||||
let ty = ctx.unifier.get_representative(lhs.custom.unwrap());
|
||||
let current =
|
||||
if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64, ctx.primitives.bool]
|
||||
.contains(&ty)
|
||||
{
|
||||
let use_unsigned_ops = [
|
||||
ctx.primitives.uint32,
|
||||
ctx.primitives.uint64,
|
||||
].contains(&ty);
|
||||
|
||||
let BasicValueEnum::IntValue(lhs) = (match generator.gen_expr(ctx, lhs)? {
|
||||
Some(v) => v.to_basic_value_enum(ctx, generator, lhs.custom.unwrap())?,
|
||||
None => return Ok(None),
|
||||
}) else { unreachable!() };
|
||||
|
||||
let BasicValueEnum::IntValue(rhs) = (match generator.gen_expr(ctx, rhs)? {
|
||||
Some(v) => v.to_basic_value_enum(ctx, generator, rhs.custom.unwrap())?,
|
||||
None => return Ok(None),
|
||||
}) else { unreachable!() };
|
||||
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ,
|
||||
ast::Cmpop::NotEq => IntPredicate::NE,
|
||||
_ if ty == ctx.primitives.bool => unreachable!(),
|
||||
ast::Cmpop::Lt => if use_unsigned_ops {
|
||||
IntPredicate::ULT
|
||||
} else {
|
||||
IntPredicate::SLT
|
||||
},
|
||||
ast::Cmpop::LtE => if use_unsigned_ops {
|
||||
IntPredicate::ULE
|
||||
} else {
|
||||
IntPredicate::SLE
|
||||
},
|
||||
ast::Cmpop::Gt => if use_unsigned_ops {
|
||||
IntPredicate::UGT
|
||||
} else {
|
||||
IntPredicate::SGT
|
||||
},
|
||||
ast::Cmpop::GtE => if use_unsigned_ops {
|
||||
IntPredicate::UGE
|
||||
} else {
|
||||
IntPredicate::SGE
|
||||
},
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap()
|
||||
} else if ty == ctx.primitives.float {
|
||||
let BasicValueEnum::FloatValue(lhs) = (match generator.gen_expr(ctx, lhs)? {
|
||||
Some(v) => v.to_basic_value_enum(ctx, generator, lhs.custom.unwrap())?,
|
||||
None => return Ok(None),
|
||||
}) else { unreachable!() };
|
||||
|
||||
let BasicValueEnum::FloatValue(rhs) = (match generator.gen_expr(ctx, rhs)? {
|
||||
Some(v) => v.to_basic_value_enum(ctx, generator, rhs.custom.unwrap())?,
|
||||
None => return Ok(None),
|
||||
}) else { unreachable!() };
|
||||
|
||||
let op = match op {
|
||||
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
|
||||
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
|
||||
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
|
||||
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
|
||||
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
|
||||
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap()
|
||||
} else {
|
||||
unimplemented!()
|
||||
};
|
||||
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
|
||||
})?;
|
||||
|
||||
match cmp_val {
|
||||
Some(v) => v.into(),
|
||||
None => return Ok(None),
|
||||
}
|
||||
}
|
||||
ExprKind::IfExp { test, body, orelse } => {
|
||||
let test = match generator.gen_expr(ctx, test)? {
|
||||
|
|
|
@ -346,31 +346,6 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
|
|||
)
|
||||
}
|
||||
|
||||
fn ndarray_fill_mapping<'ctx, G, MapFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
src: NDArrayValue<'ctx>,
|
||||
dest: NDArrayValue<'ctx>,
|
||||
map_fn: MapFn,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
ndarray_fill_flattened(
|
||||
generator,
|
||||
ctx,
|
||||
dest,
|
||||
|generator, ctx, i| {
|
||||
let elem = unsafe {
|
||||
src.data().get_unchecked(ctx, generator, i, None)
|
||||
};
|
||||
|
||||
map_fn(generator, ctx, elem)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
|
||||
/// with broadcast-compatible shapes.
|
||||
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
|
||||
|
@ -667,48 +642,6 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||
Ok(ndarray)
|
||||
}
|
||||
|
||||
pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
res: Option<NDArrayValue<'ctx>>,
|
||||
operand: NDArrayValue<'ctx>,
|
||||
map_fn: MapFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator,
|
||||
MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, BasicValueEnum<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let res = res.unwrap_or_else(|| {
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&operand,
|
||||
|_, ctx, v| {
|
||||
Ok(v.load_ndims(ctx))
|
||||
},
|
||||
|generator, ctx, v, idx| {
|
||||
unsafe {
|
||||
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None))
|
||||
}
|
||||
},
|
||||
).unwrap()
|
||||
});
|
||||
|
||||
ndarray_fill_mapping(
|
||||
generator,
|
||||
ctx,
|
||||
operand,
|
||||
res,
|
||||
|generator, ctx, elem| {
|
||||
map_fn(generator, ctx, elem_ty, elem)
|
||||
}
|
||||
)?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
|
||||
///
|
||||
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
|
||||
|
@ -728,7 +661,6 @@ pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>(
|
|||
/// # Panic
|
||||
///
|
||||
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
|
||||
// TODO: Remove elem_ty from value_fn
|
||||
pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
|
|
|
@ -46,15 +46,10 @@ impl PrimitiveDefinitionIds {
|
|||
]
|
||||
}
|
||||
|
||||
/// Returns an iterator over all [`DefinitionId`]s of this instance.
|
||||
pub fn iter(&self) -> impl Iterator<Item=DefinitionId> {
|
||||
self.as_vec().into_iter()
|
||||
}
|
||||
|
||||
/// Returns the primitive with the largest [`DefinitionId`].
|
||||
#[must_use]
|
||||
pub fn max_id(&self) -> DefinitionId {
|
||||
self.iter().max().unwrap()
|
||||
self.as_vec().into_iter().max().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -148,10 +148,8 @@ pub fn impl_binop(
|
|||
});
|
||||
}
|
||||
|
||||
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
|
||||
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryop]) {
|
||||
with_fields(unifier, ty, |unifier, fields| {
|
||||
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
|
||||
|
||||
for op in ops {
|
||||
fields.insert(
|
||||
unaryop_name(op).into(),
|
||||
|
@ -170,35 +168,19 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops:
|
|||
|
||||
pub fn impl_cmpop(
|
||||
unifier: &mut Unifier,
|
||||
_store: &PrimitiveStore,
|
||||
store: &PrimitiveStore,
|
||||
ty: Type,
|
||||
other_ty: &[Type],
|
||||
other_ty: Type,
|
||||
ops: &[Cmpop],
|
||||
ret_ty: Option<Type>,
|
||||
) {
|
||||
with_fields(unifier, ty, |unifier, fields| {
|
||||
let (other_ty, other_var_id) = if other_ty.len() == 1 {
|
||||
(other_ty[0], None)
|
||||
} else {
|
||||
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
|
||||
(ty, Some(var_id))
|
||||
};
|
||||
|
||||
let function_vars = if let Some(var_id) = other_var_id {
|
||||
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
|
||||
} else {
|
||||
VarMap::new()
|
||||
};
|
||||
|
||||
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
|
||||
|
||||
for op in ops {
|
||||
fields.insert(
|
||||
comparison_name(op).unwrap().into(),
|
||||
(
|
||||
unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
ret: ret_ty,
|
||||
vars: function_vars.clone(),
|
||||
ret: store.bool,
|
||||
vars: VarMap::new(),
|
||||
args: vec![FuncArg {
|
||||
ty: other_ty,
|
||||
default_value: None,
|
||||
|
@ -292,35 +274,34 @@ pub fn impl_mod(
|
|||
}
|
||||
|
||||
/// `UAdd`, `USub`
|
||||
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
|
||||
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
|
||||
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
|
||||
impl_unaryop(unifier, ty, ty, &[Unaryop::UAdd, Unaryop::USub]);
|
||||
}
|
||||
|
||||
/// `Invert`
|
||||
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
|
||||
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]);
|
||||
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
|
||||
impl_unaryop(unifier, ty, ty, &[Unaryop::Invert]);
|
||||
}
|
||||
|
||||
/// `Not`
|
||||
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
|
||||
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
|
||||
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_unaryop(unifier, ty, store.bool, &[Unaryop::Not]);
|
||||
}
|
||||
|
||||
/// `Lt`, `LtE`, `Gt`, `GtE`
|
||||
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Option<Type>) {
|
||||
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
|
||||
impl_cmpop(
|
||||
unifier,
|
||||
store,
|
||||
ty,
|
||||
other_ty,
|
||||
&[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE],
|
||||
ret_ty,
|
||||
);
|
||||
}
|
||||
|
||||
/// `Eq`, `NotEq`
|
||||
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Option<Type>) {
|
||||
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
|
||||
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
|
||||
impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]);
|
||||
}
|
||||
|
||||
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
|
||||
|
@ -353,7 +334,7 @@ pub fn typeof_ndarray_broadcast(
|
|||
};
|
||||
|
||||
let res_ndims = left_ty_ndims.into_iter()
|
||||
.cartesian_product(right_ty_ndims)
|
||||
.cartesian_product(right_ty_ndims.into_iter())
|
||||
.map(|(left, right)| {
|
||||
let left_val = u64::try_from(left).unwrap();
|
||||
let right_val = u64::try_from(right).unwrap();
|
||||
|
@ -361,7 +342,7 @@ pub fn typeof_ndarray_broadcast(
|
|||
max(left_val, right_val)
|
||||
})
|
||||
.unique()
|
||||
.map(SymbolValue::U64)
|
||||
.map(|ndim| SymbolValue::U64(ndim))
|
||||
.collect_vec();
|
||||
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
|
||||
|
||||
|
@ -375,9 +356,7 @@ pub fn typeof_ndarray_broadcast(
|
|||
|
||||
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
|
||||
|
||||
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
||||
Ok(ndarray_ty)
|
||||
} else {
|
||||
if !unifier.unioned(ndarray_ty_dtype, scalar_ty) {
|
||||
let (expected_ty, actual_ty) = if is_left_ndarray {
|
||||
(ndarray_ty_dtype, scalar_ty)
|
||||
} else {
|
||||
|
@ -389,6 +368,8 @@ pub fn typeof_ndarray_broadcast(
|
|||
unifier.stringify(expected_ty),
|
||||
unifier.stringify(actual_ty),
|
||||
))
|
||||
} else {
|
||||
Ok(ndarray_ty)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -401,12 +382,8 @@ pub fn typeof_binop(
|
|||
lhs: Type,
|
||||
rhs: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
let is_left_ndarray = lhs
|
||||
.obj_id(unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_right_ndarray = rhs
|
||||
.obj_id(unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_left_ndarray = lhs.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
|
||||
let is_right_ndarray = rhs.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
|
||||
|
||||
Ok(Some(match op {
|
||||
Operator::Add
|
||||
|
@ -445,8 +422,8 @@ pub fn typeof_binop(
|
|||
}
|
||||
|
||||
Operator::LShift
|
||||
| Operator::RShift => lhs,
|
||||
Operator::BitOr
|
||||
| Operator::RShift
|
||||
| Operator::BitOr
|
||||
| Operator::BitXor
|
||||
| Operator::BitAnd => {
|
||||
if unifier.unioned(lhs, rhs) {
|
||||
|
@ -458,50 +435,6 @@ pub fn typeof_binop(
|
|||
}))
|
||||
}
|
||||
|
||||
pub fn typeof_unaryop(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
op: &Unaryop,
|
||||
operand: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
|
||||
return Err("The truth value of an array with more than one element is ambiguous".to_string())
|
||||
}
|
||||
|
||||
Ok(if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
|
||||
Some(operand)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the return type given a comparison operator and its primitive operands.
|
||||
pub fn typeof_cmpop(
|
||||
unifier: &mut Unifier,
|
||||
primitives: &PrimitiveStore,
|
||||
_op: &Cmpop,
|
||||
lhs: Type,
|
||||
rhs: Type,
|
||||
) -> Result<Option<Type>, String> {
|
||||
let is_left_ndarray = lhs
|
||||
.obj_id(unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_right_ndarray = rhs
|
||||
.obj_id(unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
|
||||
Ok(Some(if is_left_ndarray || is_right_ndarray {
|
||||
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
|
||||
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
|
||||
|
||||
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
|
||||
} else if unifier.unioned(lhs, rhs) {
|
||||
primitives.bool
|
||||
} else {
|
||||
return Ok(None)
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
|
||||
let PrimitiveStore {
|
||||
int32: int32_t,
|
||||
|
@ -525,13 +458,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_invert(unifier, store, t, Some(t));
|
||||
impl_not(unifier, store, t, Some(bool_t));
|
||||
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_invert(unifier, store, t);
|
||||
impl_not(unifier, store, t);
|
||||
impl_comparison(unifier, store, t, t);
|
||||
impl_eq(unifier, store, t);
|
||||
}
|
||||
for t in [int32_t, int64_t] {
|
||||
impl_sign(unifier, store, t, Some(t));
|
||||
impl_sign(unifier, store, t);
|
||||
}
|
||||
|
||||
/* float ======== */
|
||||
|
@ -542,15 +475,14 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_sign(unifier, store, float_t, Some(float_t));
|
||||
impl_not(unifier, store, float_t, Some(bool_t));
|
||||
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_sign(unifier, store, float_t);
|
||||
impl_not(unifier, store, float_t);
|
||||
impl_comparison(unifier, store, float_t, float_t);
|
||||
impl_eq(unifier, store, float_t);
|
||||
|
||||
/* bool ======== */
|
||||
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
|
||||
impl_not(unifier, store, bool_t, Some(bool_t));
|
||||
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
|
||||
impl_not(unifier, store, bool_t);
|
||||
impl_eq(unifier, store, bool_t);
|
||||
|
||||
/* ndarray ===== */
|
||||
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||
|
@ -562,8 +494,4 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
|
||||
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
}
|
||||
|
|
|
@ -549,9 +549,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
|
|||
ExprKind::BinOp { left, op, right } => {
|
||||
Some(self.infer_bin_ops(expr.location, left, op, right, false)?)
|
||||
}
|
||||
ExprKind::UnaryOp { op, operand } => {
|
||||
Some(self.infer_unary_ops(expr.location, op, operand)?)
|
||||
}
|
||||
ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?),
|
||||
ExprKind::Compare { left, ops, comparators } => {
|
||||
Some(self.infer_compare(left, ops, comparators)?)
|
||||
}
|
||||
|
@ -1226,7 +1224,7 @@ impl<'a> Inferencer<'a> {
|
|||
};
|
||||
|
||||
let ret = if is_aug_assign {
|
||||
// The type of augmented assignment operator should never change
|
||||
// The type of an augmented assignment operator should never change
|
||||
Some(left_ty)
|
||||
} else {
|
||||
typeof_binop(
|
||||
|
@ -1235,7 +1233,7 @@ impl<'a> Inferencer<'a> {
|
|||
op,
|
||||
left_ty,
|
||||
right_ty,
|
||||
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
|
||||
).map_err(|e| HashSet::from([format!("{} (at {})", e, location)]))?
|
||||
};
|
||||
|
||||
self.build_method_call(
|
||||
|
@ -1249,20 +1247,11 @@ impl<'a> Inferencer<'a> {
|
|||
|
||||
fn infer_unary_ops(
|
||||
&mut self,
|
||||
location: Location,
|
||||
op: &ast::Unaryop,
|
||||
operand: &ast::Expr<Option<Type>>,
|
||||
) -> InferenceResult {
|
||||
let method = unaryop_name(op).into();
|
||||
|
||||
let ret = typeof_unaryop(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
op,
|
||||
operand.custom.unwrap(),
|
||||
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
|
||||
|
||||
self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], ret)
|
||||
self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], None)
|
||||
}
|
||||
|
||||
fn infer_compare(
|
||||
|
@ -1271,45 +1260,22 @@ impl<'a> Inferencer<'a> {
|
|||
ops: &[ast::Cmpop],
|
||||
comparators: &[ast::Expr<Option<Type>>],
|
||||
) -> InferenceResult {
|
||||
if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) {
|
||||
return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")]))
|
||||
}
|
||||
|
||||
let boolean = self.primitives.bool;
|
||||
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
|
||||
let method = comparison_name(c)
|
||||
.ok_or_else(|| HashSet::from([
|
||||
"unsupported comparator".to_string()
|
||||
]))?
|
||||
.into();
|
||||
|
||||
let ret = typeof_cmpop(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
c,
|
||||
a.custom.unwrap(),
|
||||
b.custom.unwrap(),
|
||||
).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?;
|
||||
|
||||
self.build_method_call(
|
||||
a.location,
|
||||
method,
|
||||
a.custom.unwrap(),
|
||||
vec![b.custom.unwrap()],
|
||||
ret,
|
||||
Some(boolean),
|
||||
)?;
|
||||
}
|
||||
|
||||
let res_lhs = comparators.iter().rev().nth(1).unwrap_or(left);
|
||||
let res_rhs = comparators.iter().rev().nth(0).unwrap();
|
||||
let res_op = ops.iter().rev().nth(0).unwrap();
|
||||
|
||||
Ok(typeof_cmpop(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
res_op,
|
||||
res_lhs.custom.unwrap(),
|
||||
res_rhs.custom.unwrap(),
|
||||
).unwrap().unwrap())
|
||||
Ok(boolean)
|
||||
}
|
||||
|
||||
/// Infers the type of a subscript expression on an `ndarray`.
|
||||
|
|
|
@ -135,15 +135,10 @@ impl TestEnvironment {
|
|||
fields: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
|
||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||
fields: HashMap::new(),
|
||||
params: VarMap::from([
|
||||
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||
]),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let primitives = PrimitiveStore {
|
||||
int32,
|
||||
|
|
|
@ -61,20 +61,14 @@ pub enum RecordKey {
|
|||
}
|
||||
|
||||
impl Type {
|
||||
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
|
||||
/// just to get the field `obj_id`.
|
||||
#[must_use]
|
||||
pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> {
|
||||
if let TypeEnum::TObj { obj_id, .. } = &*unifier.get_ty_immutable(self) {
|
||||
Some(*obj_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[deprecated = "Prefer using `Type::obj_id` instead to handle non-TObj cases."]
|
||||
// a wrapper function for cleaner code so that we don't need to
|
||||
// write this long pattern matching just to get the field `obj_id`
|
||||
pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId {
|
||||
self.obj_id(unifier).expect("expect a object type")
|
||||
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty_immutable(self).as_ref() {
|
||||
*obj_id
|
||||
} else {
|
||||
unreachable!("expect a object type")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,3 @@
|
|||
@extern
|
||||
def output_bool(x: bool):
|
||||
...
|
||||
|
||||
@extern
|
||||
def output_int32(x: int32):
|
||||
...
|
||||
|
@ -10,20 +6,10 @@ def output_int32(x: int32):
|
|||
def output_float64(x: float):
|
||||
...
|
||||
|
||||
def output_ndarray_bool_2(n: ndarray[bool, Literal[2]]):
|
||||
for r in range(len(n)):
|
||||
for c in range(len(n[r])):
|
||||
output_bool(n[r][c])
|
||||
|
||||
def output_ndarray_int32_1(n: ndarray[int32, Literal[1]]):
|
||||
for i in range(len(n)):
|
||||
output_int32(n[i])
|
||||
|
||||
def output_ndarray_int32_2(n: ndarray[int32, Literal[2]]):
|
||||
for r in range(len(n)):
|
||||
for c in range(len(n[r])):
|
||||
output_int32(n[r][c])
|
||||
|
||||
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
|
||||
for i in range(len(n)):
|
||||
output_float64(n[i])
|
||||
|
@ -422,207 +408,6 @@ def test_ndarray_ipow_broadcast_scalar():
|
|||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_pos():
|
||||
x_int32 = np_full([2, 2], -2)
|
||||
y_int32 = +x_int32
|
||||
|
||||
output_ndarray_int32_2(x_int32)
|
||||
output_ndarray_int32_2(y_int32)
|
||||
|
||||
x_float = np_full([2, 2], -2.0)
|
||||
y_float = +x_float
|
||||
|
||||
output_ndarray_float_2(x_float)
|
||||
output_ndarray_float_2(y_float)
|
||||
|
||||
def test_ndarray_neg():
|
||||
x_int32 = np_full([2, 2], -2)
|
||||
y_int32 = -x_int32
|
||||
|
||||
output_ndarray_int32_2(x_int32)
|
||||
output_ndarray_int32_2(y_int32)
|
||||
|
||||
x_float = np_full([2, 2], 2.0)
|
||||
y_float = -x_float
|
||||
|
||||
output_ndarray_float_2(x_float)
|
||||
output_ndarray_float_2(y_float)
|
||||
|
||||
def test_ndarray_inv():
|
||||
x_int32 = np_full([2, 2], -2)
|
||||
y_int32 = ~x_int32
|
||||
|
||||
output_ndarray_int32_2(x_int32)
|
||||
output_ndarray_int32_2(y_int32)
|
||||
|
||||
def test_ndarray_eq():
|
||||
x = np_identity(2)
|
||||
y = x == np_full([2, 2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_eq_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x == np_full([2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_eq_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.0 == x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_eq_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x == 0.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ne():
|
||||
x = np_identity(2)
|
||||
y = x != np_full([2, 2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ne_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x != np_full([2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ne_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.0 != x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ne_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x != 0.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_lt():
|
||||
x = np_identity(2)
|
||||
y = x < np_full([2, 2], 1.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_lt_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x < np_full([2], 1.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_lt_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 1.0 < x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_lt_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x < 1.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_le():
|
||||
x = np_identity(2)
|
||||
y = x <= np_full([2, 2], 0.5)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_le_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x <= np_full([2], 0.5)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_le_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.5 <= x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_le_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x <= 0.5
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_gt():
|
||||
x = np_identity(2)
|
||||
y = x > np_full([2, 2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_gt_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x > np_full([2], 0.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_gt_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.0 > x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_gt_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x > 0.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ge():
|
||||
x = np_identity(2)
|
||||
y = x >= np_full([2, 2], 0.5)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ge_broadcast():
|
||||
x = np_identity(2)
|
||||
y = x >= np_full([2], 0.5)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ge_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = 0.5 >= x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def test_ndarray_ge_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
y = x >= 0.5
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_bool_2(y)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
test_ndarray_empty()
|
||||
|
@ -682,32 +467,5 @@ def run() -> int32:
|
|||
test_ndarray_ipow()
|
||||
test_ndarray_ipow_broadcast()
|
||||
test_ndarray_ipow_broadcast_scalar()
|
||||
test_ndarray_pos()
|
||||
test_ndarray_neg()
|
||||
test_ndarray_inv()
|
||||
test_ndarray_eq()
|
||||
test_ndarray_eq_broadcast()
|
||||
test_ndarray_eq_broadcast_lhs_scalar()
|
||||
test_ndarray_eq_broadcast_rhs_scalar()
|
||||
test_ndarray_ne()
|
||||
test_ndarray_ne_broadcast()
|
||||
test_ndarray_ne_broadcast_lhs_scalar()
|
||||
test_ndarray_ne_broadcast_rhs_scalar()
|
||||
test_ndarray_lt()
|
||||
test_ndarray_lt_broadcast()
|
||||
test_ndarray_lt_broadcast_lhs_scalar()
|
||||
test_ndarray_lt_broadcast_rhs_scalar()
|
||||
test_ndarray_lt()
|
||||
test_ndarray_le_broadcast()
|
||||
test_ndarray_le_broadcast_lhs_scalar()
|
||||
test_ndarray_le_broadcast_rhs_scalar()
|
||||
test_ndarray_gt()
|
||||
test_ndarray_gt_broadcast()
|
||||
test_ndarray_gt_broadcast_lhs_scalar()
|
||||
test_ndarray_gt_broadcast_rhs_scalar()
|
||||
test_ndarray_gt()
|
||||
test_ndarray_ge_broadcast()
|
||||
test_ndarray_ge_broadcast_lhs_scalar()
|
||||
test_ndarray_ge_broadcast_rhs_scalar()
|
||||
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue