Compare commits

...

12 Commits

Author SHA1 Message Date
David Mak 29ae48faad core: Implement elementwise comparison operators 2024-03-28 01:11:17 +08:00
David Mak d457e6c986 core/codegen/expr: WIP - Split cmpop 2024-03-28 00:24:47 +08:00
David Mak 5450147007 core: Implement elementwise unary operators 2024-03-28 00:24:26 +08:00
David Mak 0537e816a5 core/codegen/expr: WIP - Split unaryop 2024-03-27 12:34:47 +08:00
David Mak 1f0ac6341d core: Implement elementwise binary operators
Including immediate variants of these operators.
2024-03-27 12:10:38 +08:00
David Mak e89bab64c8 core/magic_methods: Add typeof_*op
Used to determine the expected type of the binary operator with
primitive operands.
2024-03-27 12:10:38 +08:00
David Mak 4e4a61d4ac core/toplevel/numpy: Split ndarray type var utilities 2024-03-27 12:10:38 +08:00
David Mak cd9fdf3bad core: Implement calculations for broadcasting ndarrays 2024-03-27 12:10:38 +08:00
David Mak eb56259005 core/type_inferencer: Allow both int32 and isize when indexing ndarray 2024-03-27 12:10:38 +08:00
David Mak cf482cdeac core/magic_methods: Allow unknown return types
These types can be later inferred by the type inferencer.
2024-03-27 12:10:38 +08:00
David Mak d47c18bdcb core/helper: Add PrimitiveDefinitionIds::iter 2024-03-27 12:10:38 +08:00
David Mak 160cfa500a core/typedef: Add Type::obj_id to replace get_obj_id 2024-03-27 10:40:03 +08:00
17 changed files with 2270 additions and 242 deletions

View File

@ -5,7 +5,7 @@ use nac3core::{
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelDef,
},
typecheck::{
@ -654,7 +654,7 @@ impl InnerResolver {
}
}
(TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty);
let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty);
let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?;
if len == 0 {
assert!(matches!(

View File

@ -11,7 +11,7 @@ use crate::codegen::{
stmt::gen_for_callback_incrementing,
};
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
/// elements.
pub trait ArrayLikeValue<'ctx> {
/// Returns the element type of this array-like value.

View File

@ -16,6 +16,7 @@ use crate::{
get_llvm_abi_type,
irrt::*,
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
numpy,
stmt::{gen_raise, gen_var},
CodeGenContext, CodeGenTask,
},
@ -23,7 +24,7 @@ use crate::{
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::make_ndarray_ty,
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelDef,
},
typecheck::{
@ -38,6 +39,7 @@ use inkwell::{
types::{AnyType, BasicType, BasicTypeEnum},
values::{BasicValueEnum, CallSiteValue, FunctionValue, IntValue, PointerValue}
};
use inkwell::values::BasicValue;
use itertools::{chain, izip, Itertools, Either};
use nac3parser::ast::{
self, Boolop, Comprehension, Constant, Expr, ExprKind, Location, Operator, StrRef,
@ -1129,6 +1131,76 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Some("f_pow_i")
);
Ok(Some(res.into()))
} else if ty1.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray || ty2.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray {
let llvm_usize = generator.get_size_type(ctx.ctx);
let is_ndarray1 = ty1.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
let is_ndarray2 = ty2.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
let left_val = NDArrayValue::from_ptr_val(
left_val.into_pointer_value(),
llvm_usize,
None
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None },
(left_val.as_ptr_value().into(), false),
(right_val, false),
|generator, ctx, elem_ty, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), lhs),
op,
(&Some(elem_ty), rhs),
ctx.current_loc,
is_aug_assign,
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
},
)?;
Ok(Some(res.as_ptr_value().into()))
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
if is_ndarray1 { ty1 } else { ty2 },
);
let ndarray_val = NDArrayValue::from_ptr_val(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_usize,
None,
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ndarray_dtype,
if is_aug_assign { Some(ndarray_val) } else { None },
(left_val, !is_ndarray1),
(right_val, !is_ndarray2),
|generator, ctx, elem_ty, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(elem_ty), lhs),
op,
(&Some(elem_ty), rhs),
ctx.current_loc,
is_aug_assign,
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
},
)?;
Ok(Some(res.as_ptr_value().into()))
}
} else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
@ -1220,6 +1292,294 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
)
}
pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
op: &ast::Unaryop,
operand: (&Option<Type>, BasicValueEnum<'ctx>),
) -> Result<Option<ValueEnum<'ctx>>, String> {
let (ty, val) = operand;
let ty = ctx.unifier.get_representative(ty.unwrap());
Ok(Some(if ty == ctx.primitives.bool {
let val = val.into_int_value();
match op {
ast::Unaryop::Invert | ast::Unaryop::Not => {
ctx.builder.build_not(val, "not").map(Into::into).unwrap()
}
_ => val.into(),
}
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
let val = val.into_int_value();
match op {
ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").map(Into::into).unwrap(),
ast::Unaryop::Invert => ctx.builder.build_not(val, "not").map(Into::into).unwrap(),
ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").map(Into::into).unwrap(),
ast::Unaryop::UAdd => val.into(),
}
} else if ty == ctx.primitives.float {
let val = val.into_float_value();
match op {
ast::Unaryop::USub => ctx.builder.build_float_neg(val, "neg").map(Into::into).unwrap(),
ast::Unaryop::Not => ctx
.builder
.build_float_compare(
inkwell::FloatPredicate::OEQ,
val,
val.get_type().const_zero(),
"not",
)
.map(Into::into)
.unwrap(),
_ => val.into(),
}
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
let val = NDArrayValue::from_ptr_val(
val.into_pointer_value(),
llvm_usize,
None,
);
let res = numpy::ndarray_elementwise_unaryop_impl(
generator,
ctx,
ndarray_dtype,
None,
val,
|generator, ctx, elem_ty, val| {
gen_unaryop_expr_with_values(
generator,
ctx,
op,
(&Some(elem_ty), val)
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)
},
)?;
res.as_ptr_value().into()
} else {
unimplemented!()
}))
}
pub fn gen_unaryop_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
op: &ast::Unaryop,
operand: &Expr<Option<Type>>,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let val = if let Some(v) = generator.gen_expr(ctx, operand)? {
v.to_basic_value_enum(ctx, generator, operand.custom.unwrap())?
} else {
return Ok(None)
};
gen_unaryop_expr_with_values(generator, ctx, op, (&operand.custom, val))
}
pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
left: (Option<Type>, BasicValueEnum<'ctx>),
ops: &[ast::Cmpop],
comparators: &[(Option<Type>, BasicValueEnum<'ctx>)],
) -> Result<Option<ValueEnum<'ctx>>, String> {
debug_assert_eq!(comparators.len(), ops.len());
if comparators.len() == 1 {
let left_ty = ctx.unifier.get_representative(left.0.unwrap());
let right_ty = ctx.unifier.get_representative(comparators[0].0.unwrap());
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let (Some(left_ty), lhs) = left else { unreachable!() };
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
let op = ops[0].clone();
let is_ndarray1 = left_ty.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
let is_ndarray2 = right_ty.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray;
return if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
let left_val = NDArrayValue::from_ptr_val(
lhs.into_pointer_value(),
llvm_usize,
None
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ctx.primitives.bool,
None,
(left_val.as_ptr_value().into(), false),
(rhs, false),
|generator, ctx, elem_ty, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values(
generator,
ctx,
(Some(ndarray_dtype1), lhs),
&[op.clone()],
&[(Some(ndarray_dtype2), rhs)],
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
)?;
Ok(Some(res.as_ptr_value().into()))
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
if is_ndarray1 { left_ty } else { right_ty },
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ctx.primitives.bool,
None,
(lhs, !is_ndarray1),
(rhs, !is_ndarray2),
|generator, ctx, elem_ty, (lhs, rhs)| {
let val = gen_cmpop_expr_with_values(
generator,
ctx,
(Some(ndarray_dtype), lhs),
&[op.clone()],
&[(Some(ndarray_dtype), rhs)],
)?.unwrap().to_basic_value_enum(ctx, generator, elem_ty)?;
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
},
)?;
Ok(Some(res.as_ptr_value().into()))
}
}
}
let cmp_val = izip!(chain(once(&left), comparators.iter()), comparators.iter(), ops.iter(),)
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
let (left_ty, lhs) = lhs;
let (right_ty, rhs) = rhs;
let left_ty = ctx.unifier.get_representative(left_ty.unwrap());
let right_ty = ctx.unifier.get_representative(right_ty.unwrap());
let current =
if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64, ctx.primitives.bool]
.contains(&left_ty)
{
assert!(ctx.unifier.unioned(left_ty, right_ty));
let use_unsigned_ops = [
ctx.primitives.uint32,
ctx.primitives.uint64,
].contains(&left_ty);
let lhs = lhs.into_int_value();
let rhs = rhs.into_int_value();
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ,
ast::Cmpop::NotEq => IntPredicate::NE,
_ if left_ty == ctx.primitives.bool => unreachable!(),
ast::Cmpop::Lt => if use_unsigned_ops {
IntPredicate::ULT
} else {
IntPredicate::SLT
},
ast::Cmpop::LtE => if use_unsigned_ops {
IntPredicate::ULE
} else {
IntPredicate::SLE
},
ast::Cmpop::Gt => if use_unsigned_ops {
IntPredicate::UGT
} else {
IntPredicate::SGT
},
ast::Cmpop::GtE => if use_unsigned_ops {
IntPredicate::UGE
} else {
IntPredicate::SGE
},
_ => unreachable!(),
};
ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap()
} else if left_ty == ctx.primitives.float {
assert!(ctx.unifier.unioned(left_ty, right_ty));
let lhs = lhs.into_float_value();
let rhs = rhs.into_float_value();
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(),
};
ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap()
} else {
unimplemented!()
};
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
})?;
Ok(Some(match cmp_val {
Some(v) => v.into(),
None => return Ok(None),
}))
}
pub fn gen_cmpop_expr<'ctx, G: CodeGenerator>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
left: &Expr<Option<Type>>,
ops: &[ast::Cmpop],
comparators: &[Expr<Option<Type>>],
) -> Result<Option<ValueEnum<'ctx>>, String> {
let left_val = if let Some(v) = generator.gen_expr(ctx, left)? {
v.to_basic_value_enum(ctx, generator, left.custom.unwrap())?
} else {
return Ok(None)
};
let comparators = {
let mut new_comparators = Vec::new();
new_comparators.reserve(comparators.len());
for cmptor in comparators {
if let Some(v) = generator.gen_expr(ctx, cmptor)? {
new_comparators.push((cmptor.custom, v.to_basic_value_enum(ctx, generator, cmptor.custom.unwrap())?))
} else {
return Ok(None)
}
}
new_comparators
};
gen_cmpop_expr_with_values(
generator,
ctx,
(left.custom, left_val),
ops,
comparators.as_slice(),
)
}
/// Generates code for a subscript expression on an `ndarray`.
///
/// * `ty` - The `Type` of the `NDArray` elements.
@ -1596,130 +1956,10 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
return gen_binop_expr(generator, ctx, left, op, right, expr.location, false);
}
ExprKind::UnaryOp { op, operand } => {
let ty = ctx.unifier.get_representative(operand.custom.unwrap());
let val = if let Some(v) = generator.gen_expr(ctx, operand)? {
v.to_basic_value_enum(ctx, generator, operand.custom.unwrap())?
} else {
return Ok(None)
};
if ty == ctx.primitives.bool {
let val = val.into_int_value();
match op {
ast::Unaryop::Invert | ast::Unaryop::Not => {
ctx.builder.build_not(val, "not").map(Into::into).unwrap()
}
_ => val.into(),
}
} else if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty) {
let val = val.into_int_value();
match op {
ast::Unaryop::USub => ctx.builder.build_int_neg(val, "neg").map(Into::into).unwrap(),
ast::Unaryop::Invert => ctx.builder.build_not(val, "not").map(Into::into).unwrap(),
ast::Unaryop::Not => ctx.builder.build_xor(val, val.get_type().const_all_ones(), "not").map(Into::into).unwrap(),
ast::Unaryop::UAdd => val.into(),
}
} else if ty == ctx.primitives.float {
let val = val.into_float_value();
match op {
ast::Unaryop::USub => ctx.builder.build_float_neg(val, "neg").map(Into::into).unwrap(),
ast::Unaryop::Not => ctx
.builder
.build_float_compare(
inkwell::FloatPredicate::OEQ,
val,
val.get_type().const_zero(),
"not",
)
.map(Into::into)
.unwrap(),
_ => val.into(),
}
} else {
unimplemented!()
}
return gen_unaryop_expr(generator, ctx, op, operand)
}
ExprKind::Compare { left, ops, comparators } => {
let cmp_val = izip!(chain(once(left.as_ref()), comparators.iter()), comparators.iter(), ops.iter(),)
.fold(Ok(None), |prev: Result<Option<_>, String>, (lhs, rhs, op)| {
let ty = ctx.unifier.get_representative(lhs.custom.unwrap());
let current =
if [ctx.primitives.int32, ctx.primitives.int64, ctx.primitives.uint32, ctx.primitives.uint64, ctx.primitives.bool]
.contains(&ty)
{
let use_unsigned_ops = [
ctx.primitives.uint32,
ctx.primitives.uint64,
].contains(&ty);
let BasicValueEnum::IntValue(lhs) = (match generator.gen_expr(ctx, lhs)? {
Some(v) => v.to_basic_value_enum(ctx, generator, lhs.custom.unwrap())?,
None => return Ok(None),
}) else { unreachable!() };
let BasicValueEnum::IntValue(rhs) = (match generator.gen_expr(ctx, rhs)? {
Some(v) => v.to_basic_value_enum(ctx, generator, rhs.custom.unwrap())?,
None => return Ok(None),
}) else { unreachable!() };
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => IntPredicate::EQ,
ast::Cmpop::NotEq => IntPredicate::NE,
_ if ty == ctx.primitives.bool => unreachable!(),
ast::Cmpop::Lt => if use_unsigned_ops {
IntPredicate::ULT
} else {
IntPredicate::SLT
},
ast::Cmpop::LtE => if use_unsigned_ops {
IntPredicate::ULE
} else {
IntPredicate::SLE
},
ast::Cmpop::Gt => if use_unsigned_ops {
IntPredicate::UGT
} else {
IntPredicate::SGT
},
ast::Cmpop::GtE => if use_unsigned_ops {
IntPredicate::UGE
} else {
IntPredicate::SGE
},
_ => unreachable!(),
};
ctx.builder.build_int_compare(op, lhs, rhs, "cmp").unwrap()
} else if ty == ctx.primitives.float {
let BasicValueEnum::FloatValue(lhs) = (match generator.gen_expr(ctx, lhs)? {
Some(v) => v.to_basic_value_enum(ctx, generator, lhs.custom.unwrap())?,
None => return Ok(None),
}) else { unreachable!() };
let BasicValueEnum::FloatValue(rhs) = (match generator.gen_expr(ctx, rhs)? {
Some(v) => v.to_basic_value_enum(ctx, generator, rhs.custom.unwrap())?,
None => return Ok(None),
}) else { unreachable!() };
let op = match op {
ast::Cmpop::Eq | ast::Cmpop::Is => inkwell::FloatPredicate::OEQ,
ast::Cmpop::NotEq => inkwell::FloatPredicate::ONE,
ast::Cmpop::Lt => inkwell::FloatPredicate::OLT,
ast::Cmpop::LtE => inkwell::FloatPredicate::OLE,
ast::Cmpop::Gt => inkwell::FloatPredicate::OGT,
ast::Cmpop::GtE => inkwell::FloatPredicate::OGE,
_ => unreachable!(),
};
ctx.builder.build_float_compare(op, lhs, rhs, "cmp").unwrap()
} else {
unimplemented!()
};
Ok(prev?.map(|v| ctx.builder.build_and(v, current, "cmp").unwrap()).or(Some(current)))
})?;
match cmp_val {
Some(v) => v.into(),
None => return Ok(None),
}
return gen_cmpop_expr(generator, ctx, left, &ops, &comparators)
}
ExprKind::IfExp { test, body, orelse } => {
let test = match generator.gen_expr(ctx, test)? {

View File

@ -8,6 +8,8 @@ typedef unsigned _BitInt(64) uint64_t;
# define MAX(a, b) (a > b ? a : b)
# define MIN(a, b) (a > b ? b : a)
# define NULL ((void *) 0)
// adapted from GNU Scientific Library: https://git.savannah.gnu.org/cgit/gsl.git/tree/sys/pow_int.c
// need to make sure `exp >= 0` before calling this function
#define DEF_INT_EXP(T) T __nac3_int_exp_##T( \
@ -293,3 +295,87 @@ uint64_t __nac3_ndarray_flatten_index64(
}
return idx;
}
void __nac3_ndarray_calc_broadcast(
const uint32_t *lhs_dims,
uint32_t lhs_ndims,
const uint32_t *rhs_dims,
uint32_t rhs_ndims,
uint32_t *out_dims
) {
uint32_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (uint32_t i = 0; i < max_ndims; ++i) {
uint32_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
uint32_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
uint32_t *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == NULL) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == NULL) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
void __nac3_ndarray_calc_broadcast64(
const uint64_t *lhs_dims,
uint64_t lhs_ndims,
const uint64_t *rhs_dims,
uint64_t rhs_ndims,
uint64_t *out_dims
) {
uint64_t max_ndims = lhs_ndims > rhs_ndims ? lhs_ndims : rhs_ndims;
for (uint64_t i = 0; i < max_ndims; ++i) {
uint64_t *lhs_dim_sz = i < lhs_ndims ? &lhs_dims[lhs_ndims - i - 1] : NULL;
uint64_t *rhs_dim_sz = i < rhs_ndims ? &rhs_dims[rhs_ndims - i - 1] : NULL;
uint64_t *out_dim = &out_dims[max_ndims - i - 1];
if (lhs_dim_sz == NULL) {
*out_dim = *rhs_dim_sz;
} else if (rhs_dim_sz == NULL) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == 1) {
*out_dim = *rhs_dim_sz;
} else if (*rhs_dim_sz == 1) {
*out_dim = *lhs_dim_sz;
} else if (*lhs_dim_sz == *rhs_dim_sz) {
*out_dim = *lhs_dim_sz;
} else {
__builtin_unreachable();
}
}
}
void __nac3_ndarray_calc_broadcast_idx(
const uint32_t *src_dims,
uint32_t src_ndims,
const uint32_t *in_idx,
uint32_t *out_idx
) {
for (uint32_t i = 0; i < src_ndims; ++i) {
uint32_t src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : in_idx[src_i];
}
}
void __nac3_ndarray_calc_broadcast_idx64(
const uint64_t *src_dims,
uint64_t src_ndims,
const uint32_t *in_idx,
uint32_t *out_idx
) {
for (uint64_t i = 0; i < src_ndims; ++i) {
uint64_t src_i = src_ndims - i - 1;
out_idx[src_i] = src_dims[src_i] == 1 ? 0 : (uint32_t) in_idx[src_i];
}
}

View File

@ -8,9 +8,11 @@ use super::{
ListValue,
NDArrayValue,
TypedArrayLikeAdapter,
UntypedArrayLikeAccessor,
},
CodeGenContext,
CodeGenerator,
llvm_intrinsics,
};
use inkwell::{
attributes::{Attribute, AttributeLoc},
@ -783,4 +785,153 @@ pub fn call_ndarray_flatten_index<'ctx, G, Index>(
ndarray,
indices,
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast`. Returns a tuple containing the number of
/// dimension and size of each dimension of the resultant `ndarray`.
pub fn call_ndarray_calc_broadcast<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
lhs: NDArrayValue<'ctx>,
rhs: NDArrayValue<'ctx>,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast",
64 => "__nac3_ndarray_calc_broadcast64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
llvm_usize.into(),
llvm_pusize.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_ndims = rhs.load_ndims(ctx);
let max_ndims = llvm_intrinsics::call_int_umax(ctx, lhs_ndims, rhs_ndims, None);
// TODO: Generate assertion checks for whether each dimension is compatible
// gen_for_callback_incrementing(
// generator,
// ctx,
// llvm_usize.const_zero(),
// (max_ndims, false),
// |generator, ctx, idx| {
// let lhs_dim_sz =
//
// let lhs_elem = lhs.get_dims().get(ctx, generator, idx, None);
// let rhs_elem = rhs.get_dims().get(ctx, generator, idx, None);
//
//
// },
// llvm_usize.const_int(1, false),
// ).unwrap();
let lhs_dims = lhs.dim_sizes().base_ptr(ctx, generator);
let lhs_ndims = lhs.load_ndims(ctx);
let rhs_dims = rhs.dim_sizes().base_ptr(ctx, generator);
let rhs_ndims = rhs.load_ndims(ctx);
let out_dims = ctx.builder.build_array_alloca(llvm_usize, max_ndims, "").unwrap();
let out_dims = ArraySliceValue::from_ptr_val(out_dims, max_ndims, None);
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
lhs_dims.into(),
lhs_ndims.into(),
rhs_dims.into(),
rhs_ndims.into(),
out_dims.base_ptr(ctx, generator).into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
out_dims,
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}
/// Generates a call to `__nac3_ndarray_calc_broadcast_idx`. Returns an [`ArrayAllocaValue`]
/// containing the indices used for accessing `array` corresponding to the index of the broadcasted
/// array `broadcast_idx`.
pub fn call_ndarray_calc_broadcast_index<'ctx, G: CodeGenerator + ?Sized, BroadcastIdx: UntypedArrayLikeAccessor<'ctx>>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
array: NDArrayValue<'ctx>,
broadcast_idx: &BroadcastIdx,
) -> TypedArrayLikeAdapter<'ctx, IntValue<'ctx>> {
let llvm_i32 = ctx.ctx.i32_type();
let llvm_usize = generator.get_size_type(ctx.ctx);
let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default());
let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default());
let ndarray_calc_broadcast_fn_name = match llvm_usize.get_bit_width() {
32 => "__nac3_ndarray_calc_broadcast_idx",
64 => "__nac3_ndarray_calc_broadcast_idx64",
bw => unreachable!("Unsupported size type bit width: {}", bw)
};
let ndarray_calc_broadcast_fn = ctx.module.get_function(ndarray_calc_broadcast_fn_name).unwrap_or_else(|| {
let fn_type = llvm_usize.fn_type(
&[
llvm_pusize.into(),
llvm_usize.into(),
llvm_pi32.into(),
llvm_pi32.into(),
],
false,
);
ctx.module.add_function(ndarray_calc_broadcast_fn_name, fn_type, None)
});
// TODO: Assertions
let broadcast_size = broadcast_idx.size(ctx, generator);
let out_idx = ctx.builder.build_array_alloca(llvm_i32, broadcast_size, "").unwrap();
let array_dims = array.dim_sizes().base_ptr(ctx, generator);
let array_ndims = array.load_ndims(ctx);
let broadcast_idx_ptr = unsafe {
broadcast_idx.ptr_offset_unchecked(
ctx,
generator,
llvm_usize.const_zero(),
None
)
};
ctx.builder
.build_call(
ndarray_calc_broadcast_fn,
&[
array_dims.into(),
array_ndims.into(),
broadcast_idx_ptr.into(),
out_idx.into(),
],
"",
)
.unwrap();
TypedArrayLikeAdapter::from(
ArraySliceValue::from_ptr_val(out_idx, broadcast_size, None),
Box::new(|_, v| v.into_int_value()),
Box::new(|_, v| v.into()),
)
}

View File

@ -2,7 +2,7 @@ use crate::{
symbol_resolver::{StaticValue, SymbolResolver},
toplevel::{
helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_tvars,
numpy::unpack_ndarray_var_tys,
TopLevelContext,
TopLevelDef,
},
@ -451,7 +451,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>(
TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let llvm_usize = generator.get_size_type(ctx);
let (dtype, _) = unpack_ndarray_tvars(unifier, ty);
let (dtype, _) = unpack_ndarray_var_tys(unifier, ty);
let element_type = get_llvm_type(
ctx,
module,

View File

@ -18,6 +18,8 @@ use crate::{
CodeGenContext,
CodeGenerator,
irrt::{
call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index,
call_ndarray_calc_nd_indices,
call_ndarray_calc_size,
},
@ -27,7 +29,7 @@ use crate::{
symbol_resolver::ValueEnum,
toplevel::{
DefinitionId,
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
},
typecheck::typedef::{FunSignature, Type},
};
@ -344,6 +346,92 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
)
}
fn ndarray_fill_mapping<'ctx, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
src: NDArrayValue<'ctx>,
dest: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<(), String>
where
G: CodeGenerator + ?Sized,
MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, BasicValueEnum<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
ndarray_fill_flattened(
generator,
ctx,
dest,
|generator, ctx, i| {
let elem = unsafe {
src.data().get_unchecked(ctx, generator, i, None)
};
map_fn(generator, ctx, elem)
},
)
}
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
/// with broadcast-compatible shapes.
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: NDArrayValue<'ctx>,
lhs: (BasicValueEnum<'ctx>, bool),
rhs: (BasicValueEnum<'ctx>, bool),
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let (lhs_val, lhs_scalar) = lhs;
let (rhs_val, rhs_scalar) = rhs;
assert!(!(lhs_scalar && rhs_scalar),
"One of the operands must be a ndarray instance: `{}`, `{}`",
lhs_val.get_type(),
rhs_val.get_type());
ndarray_fill_indexed(
generator,
ctx,
res,
|generator, ctx, idx| {
let lhs_elem = if lhs_scalar {
lhs_val
} else {
let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, &idx);
unsafe {
lhs.data().get_unchecked(ctx, generator, lhs_idx, None)
}
};
let rhs_elem = if rhs_scalar {
rhs_val
} else {
let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, &idx);
unsafe {
rhs.data().get_unchecked(ctx, generator, rhs_idx, None)
}
};
debug_assert_eq!(lhs_elem.get_type(), rhs_elem.get_type());
value_fn(generator, ctx, elem_ty, (lhs_elem, rhs_elem))
},
)?;
Ok(res)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
///
/// * `elem_ty` - The element type of the `NDArray`.
@ -579,6 +667,151 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
Ok(ndarray)
}
pub fn ndarray_elementwise_unaryop_impl<'ctx, G, MapFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
operand: NDArrayValue<'ctx>,
map_fn: MapFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator,
MapFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, BasicValueEnum<'ctx>) -> Result<BasicValueEnum<'ctx>, String>,
{
let res = res.unwrap_or_else(|| {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&operand,
|_, ctx, v| {
Ok(v.load_ndims(ctx))
},
|generator, ctx, v, idx| {
unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None))
}
},
).unwrap()
});
ndarray_fill_mapping(
generator,
ctx,
operand,
res,
|generator, ctx, elem| {
map_fn(generator, ctx, elem_ty, elem)
}
)?;
Ok(res)
}
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
///
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple.
/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the
/// `value_fn` arguments tuple for all output elements.
///
/// The second element of the tuple indicates whether to treat the operand value as a `ndarray`
/// (which would be accessed by its broadcast index) or as a scalar value (which would be
/// broadcast to all elements).
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`.
/// * `value_fn` - Function mapping the two input elements into the result.
///
/// # Panic
///
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
// TODO: Remove elem_ty from value_fn
pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
lhs: (BasicValueEnum<'ctx>, bool),
rhs: (BasicValueEnum<'ctx>, bool),
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let (lhs_val, lhs_scalar) = lhs;
let (rhs_val, rhs_scalar) = rhs;
assert!(!(lhs_scalar && rhs_scalar),
"One of the operands must be a ndarray instance: `{}`, `{}`",
lhs_val.get_type(),
rhs_val.get_type());
let ndarray = res.unwrap_or_else(|| {
if lhs_scalar && rhs_scalar {
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray_dims,
|generator, ctx, v| {
Ok(v.size(ctx, generator))
},
|generator, ctx, v, idx| {
unsafe {
Ok(v.get_typed_unchecked(ctx, generator, idx, None))
}
},
).unwrap()
} else {
let ndarray = NDArrayValue::from_ptr_val(
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
llvm_usize,
None,
);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray,
|_, ctx, v| {
Ok(v.load_ndims(ctx))
},
|generator, ctx, v, idx| {
unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None))
}
},
).unwrap()
}
});
ndarray_broadcast_fill(
generator,
ctx,
elem_ty,
ndarray,
lhs,
rhs,
|generator, ctx, elem_ty, elems| {
value_fn(generator, ctx, elem_ty, elems)
},
)?;
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>,
@ -765,7 +998,7 @@ pub fn gen_ndarray_copy<'ctx>(
let llvm_usize = generator.get_size_type(context.ctx);
let this_ty = obj.as_ref().unwrap().0;
let (this_elem_ty, _) = unpack_ndarray_tvars(&mut context.unifier, this_ty);
let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty);
let this_arg = obj
.as_ref()
.unwrap()

View File

@ -13,7 +13,7 @@ use crate::{
toplevel::{
DefinitionId,
helper::PRIMITIVE_DEF_IDS,
numpy::unpack_ndarray_tvars,
numpy::unpack_ndarray_var_tys,
TopLevelDef,
},
typecheck::typedef::{FunSignature, Type, TypeEnum},
@ -251,7 +251,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) {
TypeEnum::TList { ty } => *ty,
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
unpack_ndarray_tvars(&mut ctx.unifier, target.custom.unwrap()).0
unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0
}
_ => unreachable!(),
};
@ -546,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
/// body(x);
/// }
/// ```
///
///
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// as the type of the loop variable.
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum

View File

@ -170,13 +170,13 @@ impl SymbolValue {
/// Returns the [`TypeAnnotation`] representing the data type of this value.
pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation {
match self {
SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool),
SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float),
SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32),
SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64),
SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32),
SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64),
SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str),
SymbolValue::Bool(..)
| SymbolValue::Double(..)
| SymbolValue::I32(..)
| SymbolValue::I64(..)
| SymbolValue::U32(..)
| SymbolValue::U64(..)
| SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)),
SymbolValue::Tuple(vs) => {
let vs_tys = vs
.iter()
@ -230,6 +230,36 @@ impl Display for SymbolValue {
}
}
impl TryFrom<SymbolValue> for u64 {
type Error = ();
/// TODO
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => Ok(v as u64),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()),
SymbolValue::U32(v) => Ok(v as u64),
SymbolValue::U64(v) => Ok(v),
_ => Err(()),
}
}
}
impl TryFrom<SymbolValue> for i128 {
type Error = ();
/// TODO
fn try_from(value: SymbolValue) -> Result<Self, Self::Error> {
match value {
SymbolValue::I32(v) => Ok(v as i128),
SymbolValue::I64(v) => Ok(v as i128),
SymbolValue::U32(v) => Ok(v as i128),
SymbolValue::U64(v) => Ok(v as i128),
_ => Err(()),
}
}
}
pub trait StaticValue {
/// Returns a unique identifier for this value.
fn get_unique_identifier(&self) -> u64;

View File

@ -299,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Some("N".into()),
None,
);
let size_t = primitives.0.usize();
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
let exception_fields = vec![
("__name__".into(), int32, true),
@ -345,8 +347,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.nth(1)
.map(|(var_id, ty)| (*ty, *var_id))
.unwrap();
let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var(
size_t,
Some("ndarray_ndims".into()),
None,
);
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap();
let ndarray_sub_ty = *ndarray_fields.get(&"__sub__".into()).unwrap();
let ndarray_mul_ty = *ndarray_fields.get(&"__mul__".into()).unwrap();
let ndarray_truediv_ty = *ndarray_fields.get(&"__truediv__".into()).unwrap();
let ndarray_floordiv_ty = *ndarray_fields.get(&"__floordiv__".into()).unwrap();
let ndarray_mod_ty = *ndarray_fields.get(&"__mod__".into()).unwrap();
let ndarray_pow_ty = *ndarray_fields.get(&"__pow__".into()).unwrap();
let ndarray_iadd_ty = *ndarray_fields.get(&"__iadd__".into()).unwrap();
let ndarray_isub_ty = *ndarray_fields.get(&"__isub__".into()).unwrap();
let ndarray_imul_ty = *ndarray_fields.get(&"__imul__".into()).unwrap();
let ndarray_itruediv_ty = *ndarray_fields.get(&"__itruediv__".into()).unwrap();
let ndarray_ifloordiv_ty = *ndarray_fields.get(&"__ifloordiv__".into()).unwrap();
let ndarray_imod_ty = *ndarray_fields.get(&"__imod__".into()).unwrap();
let ndarray_ipow_ty = *ndarray_fields.get(&"__ipow__".into()).unwrap();
let top_level_def_list = vec![
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
@ -524,6 +545,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
methods: vec![
("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)),
("__add__".into(), ndarray_add_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 3)),
("__sub__".into(), ndarray_sub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 4)),
("__mul__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 5)),
("__truediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 6)),
("__floordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 7)),
("__mod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 8)),
("__pow__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 9)),
("__iadd__".into(), ndarray_iadd_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 10)),
("__isub__".into(), ndarray_isub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 11)),
("__imul__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 12)),
("__itruediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 13)),
("__ifloordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 14)),
("__imod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 15)),
("__ipow__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 16)),
],
ancestors: Vec::default(),
constructor: None,
@ -562,6 +597,216 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__add__".into(),
simple_name: "__add__".into(),
signature: ndarray_add_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__sub__".into(),
simple_name: "__sub__".into(),
signature: ndarray_sub_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__mul__".into(),
simple_name: "__mul__".into(),
signature: ndarray_mul_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__truediv__".into(),
simple_name: "__truediv__".into(),
signature: ndarray_truediv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__floordiv__".into(),
simple_name: "__floordiv__".into(),
signature: ndarray_floordiv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__mod__".into(),
simple_name: "__mod__".into(),
signature: ndarray_mod_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__pow__".into(),
simple_name: "__pow__".into(),
signature: ndarray_pow_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__iadd__".into(),
simple_name: "__iadd__".into(),
signature: ndarray_iadd_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__isub__".into(),
simple_name: "__isub__".into(),
signature: ndarray_isub_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__imul__".into(),
simple_name: "__imul__".into(),
signature: ndarray_imul_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__itruediv__".into(),
simple_name: "__itruediv__".into(),
signature: ndarray_itruediv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__ifloordiv__".into(),
simple_name: "__ifloordiv__".into(),
signature: ndarray_ifloordiv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__imod__".into(),
simple_name: "__imod__".into(),
signature: ndarray_imod_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__ipow__".into(),
simple_name: "__ipow__".into(),
signature: ndarray_ipow_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "int32".into(),
simple_name: "int32".into(),

View File

@ -1,6 +1,7 @@
use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::subst_ndarray_tvars;
use crate::typecheck::typedef::{Mapping, VarMap};
use nac3parser::ast::{Constant, Location};
@ -45,10 +46,15 @@ impl PrimitiveDefinitionIds {
]
}
/// Returns an iterator over all [`DefinitionId`]s of this instance.
pub fn iter(&self) -> impl Iterator<Item=DefinitionId> {
self.as_vec().into_iter()
}
/// Returns the primitive with the largest [`DefinitionId`].
#[must_use]
pub fn max_id(&self) -> DefinitionId {
self.as_vec().into_iter().max().unwrap()
self.iter().max().unwrap()
}
}
@ -226,11 +232,57 @@ impl TopLevelComposer {
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray_binop_fun_other_ty = unifier.get_fresh_var(None, None);
let ndarray_binop_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_binop_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "other".into(),
ty: ndarray_binop_fun_other_ty.0,
default_value: None,
},
],
ret: ndarray_binop_fun_ret_ty.0,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray_truediv_fun_other_ty = unifier.get_fresh_var(None, None);
let ndarray_truediv_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_truediv_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "other".into(),
ty: ndarray_truediv_fun_other_ty.0,
default_value: None,
},
],
ret: ndarray_truediv_fun_ret_ty.0,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: Mapping::from([
("copy".into(), (ndarray_copy_fun_ty, true)),
("fill".into(), (ndarray_fill_fun_ty, true)),
("__add__".into(), (ndarray_binop_fun_ty, true)),
("__sub__".into(), (ndarray_binop_fun_ty, true)),
("__mul__".into(), (ndarray_binop_fun_ty, true)),
("__truediv__".into(), (ndarray_truediv_fun_ty, true)),
("__floordiv__".into(), (ndarray_binop_fun_ty, true)),
("__mod__".into(), (ndarray_binop_fun_ty, true)),
("__pow__".into(), (ndarray_binop_fun_ty, true)),
("__iadd__".into(), (ndarray_binop_fun_ty, true)),
("__isub__".into(), (ndarray_binop_fun_ty, true)),
("__imul__".into(), (ndarray_binop_fun_ty, true)),
("__itruediv__".into(), (ndarray_truediv_fun_ty, true)),
("__ifloordiv__".into(), (ndarray_binop_fun_ty, true)),
("__imod__".into(), (ndarray_binop_fun_ty, true)),
("__ipow__".into(), (ndarray_binop_fun_ty, true)),
]),
params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
@ -238,7 +290,16 @@ impl TopLevelComposer {
]),
});
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None);
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap();
unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap();
let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None);
unifier.unify(ndarray_truediv_fun_other_ty.0, ndarray).unwrap();
unifier.unify(ndarray_truediv_fun_ret_ty.0, ndarray_float).unwrap();
let primitives = PrimitiveStore {
int32,

View File

@ -19,13 +19,30 @@ pub fn make_ndarray_ty(
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let ndarray = primitives.ndarray;
subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims)
}
/// Substitutes type variables in `ndarray`.
///
/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not
/// specialized.
pub fn subst_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
dtype: Option<Type>,
ndims: Option<Type>,
) -> Type {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray);
if dtype.is_none() && ndims.is_none() {
return ndarray
}
let tvar_ids = params.iter()
.map(|(obj_id, _)| *obj_id)
.collect_vec();
@ -42,12 +59,10 @@ pub fn make_ndarray_ty(
unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray)
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_tvars(
fn unpack_ndarray_tvars(
unifier: &mut Unifier,
ndarray: Type,
) -> (Type, Type) {
) -> Vec<(u32, Type)> {
let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else {
panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray))
};
@ -56,7 +71,33 @@ pub fn unpack_ndarray_tvars(
params.iter()
.sorted_by_key(|(obj_id, _)| *obj_id)
.map(|(_, ty)| *ty)
.map(|(var_id, ty)| (*var_id, *ty))
.collect_vec()
}
/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds
/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray`
/// respectively.
pub fn unpack_ndarray_var_ids(
unifier: &mut Unifier,
ndarray: Type,
) -> (u32, u32) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.0)
.collect_tuple()
.unwrap()
}
/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to
/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively.
pub fn unpack_ndarray_var_tys(
unifier: &mut Unifier,
ndarray: Type,
) -> (Type, Type) {
unpack_ndarray_tvars(unifier, ndarray)
.into_iter()
.map(|v| v.1)
.collect_tuple()
.unwrap()
}

View File

@ -1,3 +1,7 @@
use std::cmp::max;
use crate::symbol_resolver::SymbolValue;
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
use crate::toplevel::numpy::{make_ndarray_ty, unpack_ndarray_var_tys};
use crate::typecheck::{
type_inferencer::*,
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap},
@ -6,6 +10,7 @@ use nac3parser::ast::StrRef;
use nac3parser::ast::{Cmpop, Operator, Unaryop};
use std::collections::HashMap;
use std::rc::Rc;
use itertools::Itertools;
#[must_use]
pub fn binop_name(op: &Operator) -> &'static str {
@ -90,7 +95,7 @@ pub fn impl_binop(
_store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
ret_ty: Option<Type>,
ops: &[Operator],
) {
with_fields(unifier, ty, |unifier, fields| {
@ -107,6 +112,8 @@ pub fn impl_binop(
VarMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops {
fields.insert(binop_name(op).into(), {
(
@ -141,8 +148,10 @@ pub fn impl_binop(
});
}
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryop]) {
pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Option<Type>, ops: &[Unaryop]) {
with_fields(unifier, ty, |unifier, fields| {
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops {
fields.insert(
unaryop_name(op).into(),
@ -161,19 +170,35 @@ pub fn impl_unaryop(unifier: &mut Unifier, ty: Type, ret_ty: Type, ops: &[Unaryo
pub fn impl_cmpop(
unifier: &mut Unifier,
store: &PrimitiveStore,
_store: &PrimitiveStore,
ty: Type,
other_ty: Type,
other_ty: &[Type],
ops: &[Cmpop],
ret_ty: Option<Type>,
) {
with_fields(unifier, ty, |unifier, fields| {
let (other_ty, other_var_id) = if other_ty.len() == 1 {
(other_ty[0], None)
} else {
let (ty, var_id) = unifier.get_fresh_var_with_range(other_ty, Some("N".into()), None);
(ty, Some(var_id))
};
let function_vars = if let Some(var_id) = other_var_id {
vec![(var_id, other_ty)].into_iter().collect::<VarMap>()
} else {
VarMap::new()
};
let ret_ty = ret_ty.unwrap_or_else(|| unifier.get_fresh_var(None, None).0);
for op in ops {
fields.insert(
comparison_name(op).unwrap().into(),
(
unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.bool,
vars: VarMap::new(),
ret: ret_ty,
vars: function_vars.clone(),
args: vec![FuncArg {
ty: other_ty,
default_value: None,
@ -193,7 +218,7 @@ pub fn impl_basic_arithmetic(
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
ret_ty: Option<Type>,
) {
impl_binop(
unifier,
@ -211,7 +236,7 @@ pub fn impl_pow(
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Pow]);
}
@ -223,19 +248,25 @@ pub fn impl_bitwise_arithmetic(unifier: &mut Unifier, store: &PrimitiveStore, ty
store,
ty,
&[ty],
ty,
Some(ty),
&[Operator::BitAnd, Operator::BitOr, Operator::BitXor],
);
}
/// `LShift`, `RShift`
pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_binop(unifier, store, ty, &[store.int32, store.uint32], ty, &[Operator::LShift, Operator::RShift]);
impl_binop(unifier, store, ty, &[store.int32, store.uint32], Some(ty), &[Operator::LShift, Operator::RShift]);
}
/// `Div`
pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) {
impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]);
pub fn impl_div(
unifier: &mut Unifier,
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]);
}
/// `FloorDiv`
@ -244,7 +275,7 @@ pub fn impl_floordiv(
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::FloorDiv]);
}
@ -255,40 +286,220 @@ pub fn impl_mod(
store: &PrimitiveStore,
ty: Type,
other_ty: &[Type],
ret_ty: Type,
ret_ty: Option<Type>,
) {
impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Mod]);
}
/// `UAdd`, `USub`
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, ty, &[Unaryop::UAdd, Unaryop::USub]);
pub fn impl_sign(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::UAdd, Unaryop::USub]);
}
/// `Invert`
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, ty, &[Unaryop::Invert]);
pub fn impl_invert(unifier: &mut Unifier, _store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Invert]);
}
/// `Not`
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_unaryop(unifier, ty, store.bool, &[Unaryop::Not]);
pub fn impl_not(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, ret_ty: Option<Type>) {
impl_unaryop(unifier, ty, ret_ty, &[Unaryop::Not]);
}
/// `Lt`, `LtE`, `Gt`, `GtE`
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: Type) {
pub fn impl_comparison(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Option<Type>) {
impl_cmpop(
unifier,
store,
ty,
other_ty,
&[Cmpop::Lt, Cmpop::Gt, Cmpop::LtE, Cmpop::GtE],
ret_ty,
);
}
/// `Eq`, `NotEq`
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) {
impl_cmpop(unifier, store, ty, ty, &[Cmpop::Eq, Cmpop::NotEq]);
pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type], ret_ty: Option<Type>) {
impl_cmpop(unifier, store, ty, other_ty, &[Cmpop::Eq, Cmpop::NotEq], ret_ty);
}
/// Returns the expected return type of binary operations with at least one `ndarray` operand.
pub fn typeof_ndarray_broadcast(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
left: Type,
right: Type,
) -> Result<Type, String> {
let is_left_ndarray = left.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
let is_right_ndarray = right.get_obj_id(unifier) == PRIMITIVE_DEF_IDS.ndarray;
assert!(is_left_ndarray || is_right_ndarray);
if is_left_ndarray && is_right_ndarray {
// Perform broadcasting on two ndarray operands.
let (left_ty_dtype, left_ty_ndims) = unpack_ndarray_var_tys(unifier, left);
let (right_ty_dtype, right_ty_ndims) = unpack_ndarray_var_tys(unifier, right);
assert!(unifier.unioned(left_ty_dtype, right_ty_dtype));
let left_ty_ndims = match &*unifier.get_ty_immutable(left_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let right_ty_ndims = match &*unifier.get_ty_immutable(right_ty_ndims) {
TypeEnum::TLiteral { values, .. } => values.clone(),
_ => unreachable!(),
};
let res_ndims = left_ty_ndims.into_iter()
.cartesian_product(right_ty_ndims)
.map(|(left, right)| {
let left_val = u64::try_from(left).unwrap();
let right_val = u64::try_from(right).unwrap();
max(left_val, right_val)
})
.unique()
.map(SymbolValue::U64)
.collect_vec();
let res_ndims = unifier.get_fresh_literal(res_ndims, None);
Ok(make_ndarray_ty(unifier, primitives, Some(left_ty_dtype), Some(res_ndims)))
} else {
let (ndarray_ty, scalar_ty) = if is_left_ndarray {
(left, right)
} else {
(right, left)
};
let (ndarray_ty_dtype, _) = unpack_ndarray_var_tys(unifier, ndarray_ty);
if unifier.unioned(ndarray_ty_dtype, scalar_ty) {
Ok(ndarray_ty)
} else {
let (expected_ty, actual_ty) = if is_left_ndarray {
(ndarray_ty_dtype, scalar_ty)
} else {
(scalar_ty, ndarray_ty_dtype)
};
Err(format!(
"Expected right-hand side operand to be {}, got {}",
unifier.stringify(expected_ty),
unifier.stringify(actual_ty),
))
}
}
}
/// Returns the return type given a binary operator and its primitive operands.
pub fn typeof_binop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: &Operator,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
Ok(Some(match op {
Operator::Add
| Operator::Sub
| Operator::Mult
| Operator::Mod
| Operator::FloorDiv => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None)
}
}
Operator::MatMult => typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?,
Operator::Div => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if unifier.unioned(lhs, rhs) {
primitives.float
} else {
return Ok(None)
}
}
Operator::Pow => {
if is_left_ndarray || is_right_ndarray {
typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?
} else if [primitives.int32, primitives.int64, primitives.uint32, primitives.uint64, primitives.float].into_iter().any(|ty| unifier.unioned(lhs, ty)) {
lhs
} else {
return Ok(None)
}
}
Operator::LShift
| Operator::RShift => lhs,
Operator::BitOr
| Operator::BitXor
| Operator::BitAnd => {
if unifier.unioned(lhs, rhs) {
lhs
} else {
return Ok(None)
}
}
}))
}
pub fn typeof_unaryop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
op: &Unaryop,
operand: Type,
) -> Result<Option<Type>, String> {
if *op == Unaryop::Not && operand.obj_id(unifier).is_some_and(|id| id == primitives.ndarray.obj_id(unifier).unwrap()) {
return Err("The truth value of an array with more than one element is ambiguous".to_string())
}
Ok(if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
Some(operand)
} else {
None
})
}
/// Returns the return type given a comparison operator and its primitive operands.
pub fn typeof_cmpop(
unifier: &mut Unifier,
primitives: &PrimitiveStore,
_op: &Cmpop,
lhs: Type,
rhs: Type,
) -> Result<Option<Type>, String> {
let is_left_ndarray = lhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_right_ndarray = rhs
.obj_id(unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
Ok(Some(if is_left_ndarray || is_right_ndarray {
let brd = typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)?;
let (_, ndims) = unpack_ndarray_var_tys(unifier, brd);
make_ndarray_ty(unifier, primitives, Some(primitives.bool), Some(ndims))
} else if unifier.unioned(lhs, rhs) {
primitives.bool
} else {
return Ok(None)
}))
}
pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) {
@ -299,39 +510,60 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
bool: bool_t,
uint32: uint32_t,
uint64: uint64_t,
ndarray: ndarray_t,
..
} = *store;
let size_t = store.usize();
/* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] {
impl_basic_arithmetic(unifier, store, t, &[t], t);
impl_pow(unifier, store, t, &[t], t);
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t]);
impl_floordiv(unifier, store, t, &[t], t);
impl_mod(unifier, store, t, &[t], t);
impl_invert(unifier, store, t);
impl_not(unifier, store, t);
impl_comparison(unifier, store, t, t);
impl_eq(unifier, store, t);
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, &[t, ndarray_int_t], None);
impl_eq(unifier, store, t, &[t, ndarray_int_t], None);
}
for t in [int32_t, int64_t] {
impl_sign(unifier, store, t);
impl_sign(unifier, store, t, Some(t));
}
/* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t);
impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t);
impl_div(unifier, store, float_t, &[float_t]);
impl_floordiv(unifier, store, float_t, &[float_t], float_t);
impl_mod(unifier, store, float_t, &[float_t], float_t);
impl_sign(unifier, store, float_t);
impl_not(unifier, store, float_t);
impl_comparison(unifier, store, float_t, float_t);
impl_eq(unifier, store, float_t);
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_eq(unifier, store, float_t, &[float_t, ndarray_float_t], None);
/* bool ======== */
impl_not(unifier, store, bool_t);
impl_eq(unifier, store, bool_t);
let ndarray_bool_t = make_ndarray_ty(unifier, store, Some(bool_t), None);
impl_not(unifier, store, bool_t, Some(bool_t));
impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None);
/* ndarray ===== */
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_sign(unifier, store, ndarray_t, Some(ndarray_t));
impl_invert(unifier, store, ndarray_t, Some(ndarray_t));
impl_eq(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_comparison(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
}

View File

@ -9,11 +9,11 @@ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
helper::PRIMITIVE_DEF_IDS,
numpy::{make_ndarray_ty, unpack_ndarray_tvars},
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelContext,
},
};
use itertools::izip;
use itertools::{Itertools, izip};
use nac3parser::ast::{
self,
fold::{self, Fold},
@ -59,6 +59,16 @@ pub struct PrimitiveStore {
}
impl PrimitiveStore {
/// Returns a [`Type`] representing a signed representation of `size_t`.
#[must_use]
pub fn isize(&self) -> Type {
match self.size_t {
32 => self.int32,
64 => self.int64,
_ => unreachable!(),
}
}
/// Returns a [Type] representing `size_t`.
#[must_use]
pub fn usize(&self) -> Type {
@ -539,7 +549,9 @@ impl<'a> Fold<()> for Inferencer<'a> {
ExprKind::BinOp { left, op, right } => {
Some(self.infer_bin_ops(expr.location, left, op, right, false)?)
}
ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?),
ExprKind::UnaryOp { op, operand } => {
Some(self.infer_unary_ops(expr.location, op, operand)?)
}
ExprKind::Compare { left, ops, comparators } => {
Some(self.infer_compare(left, ops, comparators)?)
}
@ -1193,8 +1205,11 @@ impl<'a> Inferencer<'a> {
right: &ast::Expr<Option<Type>>,
is_aug_assign: bool,
) -> InferenceResult {
let left_ty = left.custom.unwrap();
let right_ty = right.custom.unwrap();
let method = if let TypeEnum::TObj { fields, .. } =
self.unifier.get_ty_immutable(left.custom.unwrap()).as_ref()
self.unifier.get_ty_immutable(left_ty).as_ref()
{
let (binop_name, binop_assign_name) = (
binop_name(op).into(),
@ -1209,22 +1224,45 @@ impl<'a> Inferencer<'a> {
} else {
binop_name(op).into()
};
let ret = if is_aug_assign {
// The type of augmented assignment operator should never change
Some(left_ty)
} else {
typeof_binop(
self.unifier,
self.primitives,
op,
left_ty,
right_ty,
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
};
self.build_method_call(
location,
method,
left.custom.unwrap(),
vec![right.custom.unwrap()],
None,
left_ty,
vec![right_ty],
ret,
)
}
fn infer_unary_ops(
&mut self,
location: Location,
op: &ast::Unaryop,
operand: &ast::Expr<Option<Type>>,
) -> InferenceResult {
let method = unaryop_name(op).into();
self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], None)
let ret = typeof_unaryop(
self.unifier,
self.primitives,
op,
operand.custom.unwrap(),
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], ret)
}
fn infer_compare(
@ -1233,22 +1271,45 @@ impl<'a> Inferencer<'a> {
ops: &[ast::Cmpop],
comparators: &[ast::Expr<Option<Type>>],
) -> InferenceResult {
let boolean = self.primitives.bool;
if ops.len() > 1 && once(left).chain(comparators).any(|expr| expr.custom.unwrap().obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray)) {
return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")]))
}
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method = comparison_name(c)
.ok_or_else(|| HashSet::from([
"unsupported comparator".to_string()
]))?
.into();
let ret = typeof_cmpop(
self.unifier,
self.primitives,
c,
a.custom.unwrap(),
b.custom.unwrap(),
).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?;
self.build_method_call(
a.location,
method,
a.custom.unwrap(),
vec![b.custom.unwrap()],
Some(boolean),
ret,
)?;
}
Ok(boolean)
let res_lhs = comparators.iter().rev().nth(1).unwrap_or(left);
let res_rhs = comparators.iter().rev().nth(0).unwrap();
let res_op = ops.iter().rev().nth(0).unwrap();
Ok(typeof_cmpop(
self.unifier,
self.primitives,
res_op,
res_lhs.custom.unwrap(),
res_rhs.custom.unwrap(),
).unwrap().unwrap())
}
/// Infers the type of a subscript expression on an `ndarray`.
@ -1334,7 +1395,7 @@ impl<'a> Inferencer<'a> {
let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }),
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims))
}
@ -1347,7 +1408,7 @@ impl<'a> Inferencer<'a> {
ExprKind::Constant { value: ast::Constant::Int(val), .. } => {
match &*self.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
self.infer_subscript_ndarray(value, ty, ndims)
}
_ => {
@ -1379,9 +1440,18 @@ impl<'a> Inferencer<'a> {
Ok(ty)
}
TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => {
let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap());
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?;
let valid_index_tys = [
self.primitives.int32,
self.primitives.isize(),
].into_iter().unique().collect_vec();
let valid_index_ty = self.unifier.get_fresh_var_with_range(
valid_index_tys.as_slice(),
None,
None,
).0;
self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?;
self.infer_subscript_ndarray(value, ty, ndims)
}
_ => unreachable!(),

View File

@ -135,10 +135,15 @@ impl TestEnvironment {
fields: HashMap::new(),
params: VarMap::new(),
});
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: HashMap::new(),
params: VarMap::new(),
params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
});
let primitives = PrimitiveStore {
int32,

View File

@ -61,15 +61,21 @@ pub enum RecordKey {
}
impl Type {
// a wrapper function for cleaner code so that we don't need to
// write this long pattern matching just to get the field `obj_id`
pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId {
if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty_immutable(self).as_ref() {
*obj_id
/// Wrapper function for cleaner code so that we don't need to write this long pattern matching
/// just to get the field `obj_id`.
#[must_use]
pub fn obj_id(self, unifier: &Unifier) -> Option<DefinitionId> {
if let TypeEnum::TObj { obj_id, .. } = &*unifier.get_ty_immutable(self) {
Some(*obj_id)
} else {
unreachable!("expect a object type")
None
}
}
#[deprecated = "Prefer using `Type::obj_id` instead to handle non-TObj cases."]
pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId {
self.obj_id(unifier).expect("expect a object type")
}
}
impl From<&RecordKey> for StrRef {
@ -765,12 +771,8 @@ impl Unifier {
// If the types don't match, try to implicitly promote integers
if !self.unioned(ty, value_ty) {
let num_val = match *value {
SymbolValue::I32(v) => v as i128,
SymbolValue::I64(v) => v as i128,
SymbolValue::U32(v) => v as i128,
SymbolValue::U64(v) => v as i128,
_ => return Self::incompatible_types(a, b),
let Ok(num_val) = i128::try_from(value.clone()) else {
return Self::incompatible_types(a, b)
};
let can_convert = if self.unioned(ty, primitives.int32) {

View File

@ -1,3 +1,7 @@
@extern
def output_bool(x: bool):
...
@extern
def output_int32(x: int32):
...
@ -6,6 +10,29 @@ def output_int32(x: int32):
def output_float64(x: float):
...
def output_ndarray_bool_2(n: ndarray[bool, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_bool(n[r][c])
def output_ndarray_int32_1(n: ndarray[int32, Literal[1]]):
for i in range(len(n)):
output_int32(n[i])
def output_ndarray_int32_2(n: ndarray[int32, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_int32(n[r][c])
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
for i in range(len(n)):
output_float64(n[i])
def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_float64(n[r][c])
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass
@ -19,53 +46,582 @@ def test_ndarray_empty():
def test_ndarray_zeros():
n: ndarray[float, 1] = np_zeros([1])
output_float64(n[0])
output_ndarray_float_1(n)
def test_ndarray_ones():
n: ndarray[float, 1] = np_ones([1])
output_float64(n[0])
output_ndarray_float_1(n)
def test_ndarray_full():
n_float: ndarray[float, 1] = np_full([1], 2.0)
output_float64(n_float[0])
output_ndarray_float_1(n_float)
n_i32: ndarray[int32, 1] = np_full([1], 2)
output_int32(n_i32[0])
output_ndarray_int32_1(n_i32)
def test_ndarray_eye():
n: ndarray[float, 2] = np_eye(2)
n0: ndarray[float, 1] = n[0]
v: float = n0[0]
output_float64(v)
output_ndarray_float_2(n)
def test_ndarray_identity():
n: ndarray[float, 2] = np_identity(2)
output_float64(n[0][0])
output_float64(n[0][1])
output_float64(n[1][0])
output_float64(n[1][1])
output_ndarray_float_2(n)
def test_ndarray_fill():
n: ndarray[float, 2] = np_empty([2, 2])
n.fill(1.0)
output_float64(n[0][0])
output_float64(n[0][1])
output_float64(n[1][0])
output_float64(n[1][1])
output_ndarray_float_2(n)
def test_ndarray_copy():
x: ndarray[float, 2] = np_identity(2)
y = x.copy()
x.fill(0.0)
output_float64(x[0][0])
output_float64(x[0][1])
output_float64(x[1][0])
output_float64(x[1][1])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
output_float64(y[0][0])
output_float64(y[0][1])
output_float64(y[1][0])
output_float64(y[1][1])
def test_ndarray_add():
x = np_identity(2)
y = x + np_ones([2, 2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_add_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x + np_ones([2])
y = x + np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_add_broadcast_lhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = 1.0 + x
y = 1.0 + x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_add_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x + 1.0
y = x + 1.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_iadd():
x = np_identity(2)
x += np_ones([2, 2])
output_ndarray_float_2(x)
def test_ndarray_iadd_broadcast():
x = np_identity(2)
x += np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_iadd_broadcast_scalar():
x = np_identity(2)
x += 1.0
output_ndarray_float_2(x)
def test_ndarray_sub():
x = np_ones([2, 2])
y = x - np_identity(2)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_sub_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x - np_ones([2])
y = x - np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_sub_broadcast_lhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = 1.0 - x
y = 1.0 - x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_sub_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x - 1
y = x - 1.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_isub():
x = np_ones([2, 2])
x -= np_identity(2)
output_ndarray_float_2(x)
def test_ndarray_isub_broadcast():
x = np_identity(2)
x -= np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_isub_broadcast_scalar():
x = np_identity(2)
x -= 1.0
output_ndarray_float_2(x)
def test_ndarray_mul():
x = np_ones([2, 2])
y = x * np_identity(2)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mul_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x * np_ones([2])
y = x * np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mul_broadcast_lhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = 2.0 * x
y = 2.0 * x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mul_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x * 2.0
y = x * 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_imul():
x = np_ones([2, 2])
x *= np_identity(2)
output_ndarray_float_2(x)
def test_ndarray_imul_broadcast():
x = np_identity(2)
x *= np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_imul_broadcast_scalar():
x = np_identity(2)
x *= 2.0
output_ndarray_float_2(x)
def test_ndarray_truediv():
x = np_identity(2)
y = x / np_ones([2, 2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_truediv_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x / np_ones([2])
y = x / np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_truediv_broadcast_lhs_scalar():
x = np_ones([2, 2])
# y: ndarray[float, 2] = 2.0 / x
y = 2.0 / x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_truediv_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x / 2.0
y = x / 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_itruediv():
x = np_identity(2)
x /= np_ones([2, 2])
output_ndarray_float_2(x)
def test_ndarray_itruediv_broadcast():
x = np_identity(2)
x /= np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_itruediv_broadcast_scalar():
x = np_identity(2)
x /= 2.0
output_ndarray_float_2(x)
def test_ndarray_floordiv():
x = np_identity(2)
y = x // np_ones([2, 2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_floordiv_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x // np_ones([2])
y = x // np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_floordiv_broadcast_lhs_scalar():
x = np_ones([2, 2])
# y: ndarray[float, 2] = 2.0 // x
y = 2.0 // x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_floordiv_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x // 2.0
y = x // 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_ifloordiv():
x = np_identity(2)
x //= np_ones([2, 2])
output_ndarray_float_2(x)
def test_ndarray_ifloordiv_broadcast():
x = np_identity(2)
x //= np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_ifloordiv_broadcast_scalar():
x = np_identity(2)
x //= 2.0
output_ndarray_float_2(x)
def test_ndarray_mod():
x = np_identity(2)
y = x % np_full([2, 2], 2.0)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mod_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x % np_ones([2])
y = x % np_full([2], 2.0)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mod_broadcast_lhs_scalar():
x = np_ones([2, 2])
# y: ndarray[float, 2] = 2.0 % x
y = 2.0 % x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mod_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x % 2.0
y = x % 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_imod():
x = np_identity(2)
x %= np_full([2, 2], 2.0)
output_ndarray_float_2(x)
def test_ndarray_imod_broadcast():
x = np_identity(2)
x %= np_full([2], 2.0)
output_ndarray_float_2(x)
def test_ndarray_imod_broadcast_scalar():
x = np_identity(2)
x %= 2.0
output_ndarray_float_2(x)
def test_ndarray_pow():
x = np_identity(2)
y = x ** np_full([2, 2], 2.0)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pow_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x ** np_full([2], 2.0)
y = x ** np_full([2], 2.0)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pow_broadcast_lhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = 2.0 ** x
y = 2.0 ** x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pow_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x % 2.0
y = x ** 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_ipow():
x = np_identity(2)
x **= np_full([2, 2], 2.0)
output_ndarray_float_2(x)
def test_ndarray_ipow_broadcast():
x = np_identity(2)
x **= np_full([2], 2.0)
output_ndarray_float_2(x)
def test_ndarray_ipow_broadcast_scalar():
x = np_identity(2)
x **= 2.0
output_ndarray_float_2(x)
def test_ndarray_pos():
x_int32 = np_full([2, 2], -2)
y_int32 = +x_int32
output_ndarray_int32_2(x_int32)
output_ndarray_int32_2(y_int32)
x_float = np_full([2, 2], -2.0)
y_float = +x_float
output_ndarray_float_2(x_float)
output_ndarray_float_2(y_float)
def test_ndarray_neg():
x_int32 = np_full([2, 2], -2)
y_int32 = -x_int32
output_ndarray_int32_2(x_int32)
output_ndarray_int32_2(y_int32)
x_float = np_full([2, 2], 2.0)
y_float = -x_float
output_ndarray_float_2(x_float)
output_ndarray_float_2(y_float)
def test_ndarray_inv():
x_int32 = np_full([2, 2], -2)
y_int32 = ~x_int32
output_ndarray_int32_2(x_int32)
output_ndarray_int32_2(y_int32)
def test_ndarray_eq():
x = np_identity(2)
y = x == np_full([2, 2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_eq_broadcast():
x = np_identity(2)
y = x == np_full([2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_eq_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.0 == x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_eq_broadcast_rhs_scalar():
x = np_identity(2)
y = x == 0.0
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ne():
x = np_identity(2)
y = x != np_full([2, 2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ne_broadcast():
x = np_identity(2)
y = x != np_full([2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ne_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.0 != x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ne_broadcast_rhs_scalar():
x = np_identity(2)
y = x != 0.0
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_lt():
x = np_identity(2)
y = x < np_full([2, 2], 1.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_lt_broadcast():
x = np_identity(2)
y = x < np_full([2], 1.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_lt_broadcast_lhs_scalar():
x = np_identity(2)
y = 1.0 < x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_lt_broadcast_rhs_scalar():
x = np_identity(2)
y = x < 1.0
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_le():
x = np_identity(2)
y = x <= np_full([2, 2], 0.5)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_le_broadcast():
x = np_identity(2)
y = x <= np_full([2], 0.5)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_le_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.5 <= x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_le_broadcast_rhs_scalar():
x = np_identity(2)
y = x <= 0.5
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_gt():
x = np_identity(2)
y = x > np_full([2, 2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_gt_broadcast():
x = np_identity(2)
y = x > np_full([2], 0.0)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_gt_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.0 > x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_gt_broadcast_rhs_scalar():
x = np_identity(2)
y = x > 0.0
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ge():
x = np_identity(2)
y = x >= np_full([2, 2], 0.5)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ge_broadcast():
x = np_identity(2)
y = x >= np_full([2], 0.5)
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ge_broadcast_lhs_scalar():
x = np_identity(2)
y = 0.5 >= x
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def test_ndarray_ge_broadcast_rhs_scalar():
x = np_identity(2)
y = x >= 0.5
output_ndarray_float_2(x)
output_ndarray_bool_2(y)
def run() -> int32:
test_ndarray_ctor()
@ -77,5 +633,81 @@ def run() -> int32:
test_ndarray_identity()
test_ndarray_fill()
test_ndarray_copy()
test_ndarray_add()
test_ndarray_add_broadcast()
test_ndarray_add_broadcast_lhs_scalar()
test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_iadd()
test_ndarray_iadd_broadcast()
test_ndarray_iadd_broadcast_scalar()
test_ndarray_sub()
test_ndarray_sub_broadcast()
test_ndarray_sub_broadcast_lhs_scalar()
test_ndarray_sub_broadcast_rhs_scalar()
test_ndarray_isub()
test_ndarray_isub_broadcast()
test_ndarray_isub_broadcast_scalar()
test_ndarray_mul()
test_ndarray_mul_broadcast()
test_ndarray_mul_broadcast_lhs_scalar()
test_ndarray_mul_broadcast_rhs_scalar()
test_ndarray_imul()
test_ndarray_imul_broadcast()
test_ndarray_imul_broadcast_scalar()
test_ndarray_truediv()
test_ndarray_truediv_broadcast()
test_ndarray_truediv_broadcast_lhs_scalar()
test_ndarray_truediv_broadcast_rhs_scalar()
test_ndarray_itruediv()
test_ndarray_itruediv_broadcast()
test_ndarray_itruediv_broadcast_scalar()
test_ndarray_floordiv()
test_ndarray_floordiv_broadcast()
test_ndarray_floordiv_broadcast_lhs_scalar()
test_ndarray_floordiv_broadcast_rhs_scalar()
test_ndarray_ifloordiv()
test_ndarray_ifloordiv_broadcast()
test_ndarray_ifloordiv_broadcast_scalar()
test_ndarray_mod()
test_ndarray_mod_broadcast()
test_ndarray_mod_broadcast_lhs_scalar()
test_ndarray_mod_broadcast_rhs_scalar()
test_ndarray_imod()
test_ndarray_imod_broadcast()
test_ndarray_imod_broadcast_scalar()
test_ndarray_pow()
test_ndarray_pow_broadcast()
test_ndarray_pow_broadcast_lhs_scalar()
test_ndarray_pow_broadcast_rhs_scalar()
test_ndarray_ipow()
test_ndarray_ipow_broadcast()
test_ndarray_ipow_broadcast_scalar()
test_ndarray_pos()
test_ndarray_neg()
test_ndarray_inv()
test_ndarray_eq()
test_ndarray_eq_broadcast()
test_ndarray_eq_broadcast_lhs_scalar()
test_ndarray_eq_broadcast_rhs_scalar()
test_ndarray_ne()
test_ndarray_ne_broadcast()
test_ndarray_ne_broadcast_lhs_scalar()
test_ndarray_ne_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_lt_broadcast()
test_ndarray_lt_broadcast_lhs_scalar()
test_ndarray_lt_broadcast_rhs_scalar()
test_ndarray_lt()
test_ndarray_le_broadcast()
test_ndarray_le_broadcast_lhs_scalar()
test_ndarray_le_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_gt_broadcast()
test_ndarray_gt_broadcast_lhs_scalar()
test_ndarray_gt_broadcast_rhs_scalar()
test_ndarray_gt()
test_ndarray_ge_broadcast()
test_ndarray_ge_broadcast_lhs_scalar()
test_ndarray_ge_broadcast_rhs_scalar()
return 0