diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index ccba682..5d7621c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -5,7 +5,7 @@ use nac3core::{ toplevel::{ DefinitionId, helper::PRIMITIVE_DEF_IDS, - numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelDef, }, typecheck::{ @@ -654,7 +654,7 @@ impl InnerResolver { } } (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty); + let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { assert!(matches!( diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 84d86b2..3df2d63 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -11,7 +11,7 @@ use crate::codegen::{ stmt::gen_for_callback_incrementing, }; -/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of +/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of /// elements. pub trait ArrayLikeValue<'ctx> { /// Returns the element type of this array-like value. diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 64df45e..c0f44d1 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -16,6 +16,7 @@ use crate::{ get_llvm_abi_type, irrt::*, llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi}, + numpy, stmt::{gen_raise, gen_var}, CodeGenContext, CodeGenTask, }, @@ -23,7 +24,7 @@ use crate::{ toplevel::{ DefinitionId, helper::PRIMITIVE_DEF_IDS, - numpy::make_ndarray_ty, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelDef, }, typecheck::{ @@ -1129,6 +1130,76 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Some("f_pow_i") ); Ok(Some(res.into())) + } else if ty1.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray || ty2.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let is_ndarray1 = ty1.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray; + let is_ndarray2 = ty2.get_obj_id(&ctx.unifier) == PRIMITIVE_DEF_IDS.ndarray; + + if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); + + assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + let left_val = NDArrayValue::from_ptr_val( + left_val.into_pointer_value(), + llvm_usize, + None + ); + let res = numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + ndarray_dtype1, + if is_aug_assign { Some(left_val) } else { None }, + (left_val.as_ptr_value().into(), false), + (right_val.into(), false), + |generator, ctx, elem_ty, (lhs, rhs)| { + gen_binop_expr_with_values( + generator, + ctx, + (&Some(elem_ty), lhs), + op, + (&Some(elem_ty), rhs), + ctx.current_loc, + is_aug_assign, + )?.unwrap().to_basic_value_enum(ctx, generator, elem_ty) + }, + )?; + + Ok(Some(res.as_ptr_value().into())) + } else { + let (ndarray_dtype, _) = unpack_ndarray_var_tys( + &mut ctx.unifier, + if is_ndarray1 { ty1 } else { ty2 }, + ); + let ndarray_val = NDArrayValue::from_ptr_val( + if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), + llvm_usize, + None, + ); + let res = numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + ndarray_dtype, + if is_aug_assign { Some(ndarray_val) } else { None }, + (left_val, !is_ndarray1), + (right_val, !is_ndarray2), + |generator, ctx, elem_ty, (lhs, rhs)| { + gen_binop_expr_with_values( + generator, + ctx, + (&Some(elem_ty), lhs), + op, + (&Some(elem_ty), rhs), + ctx.current_loc, + is_aug_assign, + )?.unwrap().to_basic_value_enum(ctx, generator, elem_ty) + }, + )?; + + Ok(Some(res.as_ptr_value().into())) + } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index db2825d..7f3c65d 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -18,6 +18,8 @@ use crate::{ CodeGenContext, CodeGenerator, irrt::{ + call_ndarray_calc_broadcast, + call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, }, @@ -344,6 +346,67 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>( ) } +/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value +/// with broadcast-compatible shapes. +fn ndarray_broadcast_fill<'ctx, G, ValueFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + res: NDArrayValue<'ctx>, + lhs: (BasicValueEnum<'ctx>, bool), + rhs: (BasicValueEnum<'ctx>, bool), + value_fn: ValueFn, +) -> Result, String> + where + G: CodeGenerator + ?Sized, + ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result, String>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (lhs_val, lhs_scalar) = lhs; + let (rhs_val, rhs_scalar) = rhs; + + assert!(!(lhs_scalar && rhs_scalar), + "One of the operands must be a ndarray instance: `{}`, `{}`", + lhs_val.get_type(), + rhs_val.get_type()); + + ndarray_fill_indexed( + generator, + ctx, + res, + |generator, ctx, idx| { + let lhs_elem = if lhs_scalar { + lhs_val + } else { + let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, &idx); + + unsafe { + lhs.data().get_unchecked(ctx, generator, lhs_idx, None) + } + }; + + let rhs_elem = if rhs_scalar { + rhs_val + } else { + let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, &idx); + + unsafe { + rhs.data().get_unchecked(ctx, generator, rhs_idx, None) + } + }; + + debug_assert_eq!(lhs_elem.get_type(), rhs_elem.get_type()); + + value_fn(generator, ctx, elem_ty, (lhs_elem, rhs_elem)) + }, + )?; + + Ok(res) +} + /// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -579,6 +642,108 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( Ok(ndarray) } +/// LLVM-typed implementation for computing elementwise binary operations on two input operands. +/// +/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output +/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. +/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the +/// `value_fn` arguments tuple for all output elements. +/// +/// The second element of the tuple indicates whether to treat the operand value as a `ndarray` +/// (which would be accessed by its broadcast index) or as a scalar value (which would be +/// broadcast to all elements). +/// +/// * `elem_ty` - The element type of the `NDArray`. +/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be +/// written to a new `ndarray`. +/// * `value_fn` - Function mapping the two input elements into the result. +/// +/// # Panic +/// +/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`. +pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + res: Option>, + lhs: (BasicValueEnum<'ctx>, bool), + rhs: (BasicValueEnum<'ctx>, bool), + value_fn: ValueFn, +) -> Result, String> + where + G: CodeGenerator, + ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, Type, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result, String>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (lhs_val, lhs_scalar) = lhs; + let (rhs_val, rhs_scalar) = rhs; + + assert!(!(lhs_scalar && rhs_scalar), + "One of the operands must be a ndarray instance: `{}`, `{}`", + lhs_val.get_type(), + rhs_val.get_type()); + + let ndarray = res.unwrap_or_else(|| { + if lhs_scalar && rhs_scalar { + let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + + let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); + + create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &ndarray_dims, + |generator, ctx, v| { + Ok(v.size(ctx, generator)) + }, + |generator, ctx, v, idx| { + unsafe { + Ok(v.get_typed_unchecked(ctx, generator, idx, None)) + } + }, + ).unwrap() + } else { + let ndarray = NDArrayValue::from_ptr_val( + if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), + llvm_usize, + None, + ); + + create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &ndarray, + |_, ctx, v| { + Ok(v.load_ndims(ctx)) + }, + |generator, ctx, v, idx| { + unsafe { + Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None)) + } + }, + ).unwrap() + } + }); + + ndarray_broadcast_fill( + generator, + ctx, + elem_ty, + ndarray, + lhs, + rhs, + |generator, ctx, elem_ty, elems| { + value_fn(generator, ctx, elem_ty, elems) + }, + )?; + + Ok(ndarray) +} + /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 1e0ab2c..fd5d51a 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -546,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( /// body(x); /// } /// ``` -/// +/// /// * `init_val` - The initial value of the loop variable. The type of this value will also be used /// as the type of the loop variable. /// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index de41ac1..c9cea93 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -170,13 +170,13 @@ impl SymbolValue { /// Returns the [`TypeAnnotation`] representing the data type of this value. pub fn get_type_annotation(&self, primitives: &PrimitiveStore, unifier: &mut Unifier) -> TypeAnnotation { match self { - SymbolValue::Bool(..) => TypeAnnotation::Primitive(primitives.bool), - SymbolValue::Double(..) => TypeAnnotation::Primitive(primitives.float), - SymbolValue::I32(..) => TypeAnnotation::Primitive(primitives.int32), - SymbolValue::I64(..) => TypeAnnotation::Primitive(primitives.int64), - SymbolValue::U32(..) => TypeAnnotation::Primitive(primitives.uint32), - SymbolValue::U64(..) => TypeAnnotation::Primitive(primitives.uint64), - SymbolValue::Str(..) => TypeAnnotation::Primitive(primitives.str), + SymbolValue::Bool(..) + | SymbolValue::Double(..) + | SymbolValue::I32(..) + | SymbolValue::I64(..) + | SymbolValue::U32(..) + | SymbolValue::U64(..) + | SymbolValue::Str(..) => TypeAnnotation::Primitive(self.get_type(primitives, unifier)), SymbolValue::Tuple(vs) => { let vs_tys = vs .iter() @@ -230,6 +230,36 @@ impl Display for SymbolValue { } } +impl TryFrom for u64 { + type Error = (); + + /// TODO + fn try_from(value: SymbolValue) -> Result { + match value { + SymbolValue::I32(v) => Ok(v as u64), + SymbolValue::I64(v) => u64::try_from(v).map_err(|_| ()), + SymbolValue::U32(v) => Ok(v as u64), + SymbolValue::U64(v) => Ok(v), + _ => Err(()), + } + } +} + +impl TryFrom for i128 { + type Error = (); + + /// TODO + fn try_from(value: SymbolValue) -> Result { + match value { + SymbolValue::I32(v) => Ok(v as i128), + SymbolValue::I64(v) => Ok(v as i128), + SymbolValue::U32(v) => Ok(v as i128), + SymbolValue::U64(v) => Ok(v as i128), + _ => Err(()), + } + } +} + pub trait StaticValue { /// Returns a unique identifier for this value. fn get_unique_identifier(&self) -> u64; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index d671eba..61f4e31 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -299,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Some("N".into()), None, ); + let size_t = primitives.0.usize(); + let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let exception_fields = vec![ ("__name__".into(), int32, true), @@ -345,8 +347,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .nth(1) .map(|(var_id, ty)| (*ty, *var_id)) .unwrap(); + let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var( + size_t, + Some("ndarray_ndims".into()), + None, + ); let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); + let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap(); + let ndarray_sub_ty = *ndarray_fields.get(&"__sub__".into()).unwrap(); + let ndarray_mul_ty = *ndarray_fields.get(&"__mul__".into()).unwrap(); + let ndarray_truediv_ty = *ndarray_fields.get(&"__truediv__".into()).unwrap(); + let ndarray_floordiv_ty = *ndarray_fields.get(&"__floordiv__".into()).unwrap(); + let ndarray_mod_ty = *ndarray_fields.get(&"__mod__".into()).unwrap(); + let ndarray_pow_ty = *ndarray_fields.get(&"__pow__".into()).unwrap(); + let ndarray_iadd_ty = *ndarray_fields.get(&"__iadd__".into()).unwrap(); + let ndarray_isub_ty = *ndarray_fields.get(&"__isub__".into()).unwrap(); + let ndarray_imul_ty = *ndarray_fields.get(&"__imul__".into()).unwrap(); + let ndarray_itruediv_ty = *ndarray_fields.get(&"__itruediv__".into()).unwrap(); + let ndarray_ifloordiv_ty = *ndarray_fields.get(&"__ifloordiv__".into()).unwrap(); + let ndarray_imod_ty = *ndarray_fields.get(&"__imod__".into()).unwrap(); + let ndarray_ipow_ty = *ndarray_fields.get(&"__ipow__".into()).unwrap(); let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( @@ -524,6 +545,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { methods: vec![ ("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)), ("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)), + ("__add__".into(), ndarray_add_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 3)), + ("__sub__".into(), ndarray_sub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 4)), + ("__mul__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 5)), + ("__truediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 6)), + ("__floordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 7)), + ("__mod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 8)), + ("__pow__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 9)), + ("__iadd__".into(), ndarray_iadd_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 10)), + ("__isub__".into(), ndarray_isub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 11)), + ("__imul__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 12)), + ("__itruediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 13)), + ("__ifloordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 14)), + ("__imod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 15)), + ("__ipow__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 16)), ], ancestors: Vec::default(), constructor: None, @@ -562,6 +597,216 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__add__".into(), + simple_name: "__add__".into(), + signature: ndarray_add_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__sub__".into(), + simple_name: "__sub__".into(), + signature: ndarray_sub_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__mul__".into(), + simple_name: "__mul__".into(), + signature: ndarray_mul_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__truediv__".into(), + simple_name: "__truediv__".into(), + signature: ndarray_truediv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__floordiv__".into(), + simple_name: "__floordiv__".into(), + signature: ndarray_floordiv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__mod__".into(), + simple_name: "__mod__".into(), + signature: ndarray_mod_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__pow__".into(), + simple_name: "__pow__".into(), + signature: ndarray_pow_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__iadd__".into(), + simple_name: "__iadd__".into(), + signature: ndarray_iadd_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__isub__".into(), + simple_name: "__isub__".into(), + signature: ndarray_isub_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__imul__".into(), + simple_name: "__imul__".into(), + signature: ndarray_imul_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__itruediv__".into(), + simple_name: "__itruediv__".into(), + signature: ndarray_itruediv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__ifloordiv__".into(), + simple_name: "__ifloordiv__".into(), + signature: ndarray_ifloordiv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__imod__".into(), + simple_name: "__imod__".into(), + signature: ndarray_imod_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__ipow__".into(), + simple_name: "__ipow__".into(), + signature: ndarray_ipow_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 8a3908e..f00d754 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,6 +1,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; +use crate::toplevel::numpy::subst_ndarray_tvars; use crate::typecheck::typedef::{Mapping, VarMap}; use nac3parser::ast::{Constant, Location}; @@ -226,11 +227,57 @@ impl TopLevelComposer { (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), ]), })); + let ndarray_binop_fun_other_ty = unifier.get_fresh_var(None, None); + let ndarray_binop_fun_ret_ty = unifier.get_fresh_var(None, None); + let ndarray_binop_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "other".into(), + ty: ndarray_binop_fun_other_ty.0, + default_value: None, + }, + ], + ret: ndarray_binop_fun_ret_ty.0, + vars: VarMap::from([ + (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), + (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), + ]), + })); + let ndarray_truediv_fun_other_ty = unifier.get_fresh_var(None, None); + let ndarray_truediv_fun_ret_ty = unifier.get_fresh_var(None, None); + let ndarray_truediv_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "other".into(), + ty: ndarray_truediv_fun_other_ty.0, + default_value: None, + }, + ], + ret: ndarray_truediv_fun_ret_ty.0, + vars: VarMap::from([ + (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), + (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), + ]), + })); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PRIMITIVE_DEF_IDS.ndarray, fields: Mapping::from([ ("copy".into(), (ndarray_copy_fun_ty, true)), ("fill".into(), (ndarray_fill_fun_ty, true)), + ("__add__".into(), (ndarray_binop_fun_ty, true)), + ("__sub__".into(), (ndarray_binop_fun_ty, true)), + ("__mul__".into(), (ndarray_binop_fun_ty, true)), + ("__truediv__".into(), (ndarray_truediv_fun_ty, true)), + ("__floordiv__".into(), (ndarray_binop_fun_ty, true)), + ("__mod__".into(), (ndarray_binop_fun_ty, true)), + ("__pow__".into(), (ndarray_binop_fun_ty, true)), + ("__iadd__".into(), (ndarray_binop_fun_ty, true)), + ("__isub__".into(), (ndarray_binop_fun_ty, true)), + ("__imul__".into(), (ndarray_binop_fun_ty, true)), + ("__itruediv__".into(), (ndarray_truediv_fun_ty, true)), + ("__ifloordiv__".into(), (ndarray_binop_fun_ty, true)), + ("__imod__".into(), (ndarray_binop_fun_ty, true)), + ("__ipow__".into(), (ndarray_binop_fun_ty, true)), ]), params: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), @@ -238,7 +285,16 @@ impl TopLevelComposer { ]), }); + let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); + let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None); + unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap(); + unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap(); + unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap(); + + let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None); + unifier.unify(ndarray_truediv_fun_other_ty.0, ndarray).unwrap(); + unifier.unify(ndarray_truediv_fun_ret_ty.0, ndarray_float).unwrap(); let primitives = PrimitiveStore { int32, diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index af31495..ad44f34 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -241,8 +241,14 @@ pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Typ } /// `Div` -pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) { - impl_binop(unifier, store, ty, other_ty, Some(store.float), &[Operator::Div]); +pub fn impl_div( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Option, +) { + impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]); } /// `FloorDiv` @@ -437,18 +443,21 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie bool: bool_t, uint32: uint32_t, uint64: uint64_t, + ndarray: ndarray_t, .. } = *store; + let size_t = store.usize(); /* int ======== */ for t in [int32_t, int64_t, uint32_t, uint64_t] { - impl_basic_arithmetic(unifier, store, t, &[t], Some(t)); - impl_pow(unifier, store, t, &[t], Some(t)); + let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None); + impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None); + impl_pow(unifier, store, t, &[t, ndarray_int_t], None); impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_shift(unifier, store, t); - impl_div(unifier, store, t, &[t]); - impl_floordiv(unifier, store, t, &[t], Some(t)); - impl_mod(unifier, store, t, &[t], Some(t)); + impl_div(unifier, store, t, &[t, ndarray_int_t], None); + impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None); + impl_mod(unifier, store, t, &[t, ndarray_int_t], None); impl_invert(unifier, store, t); impl_not(unifier, store, t); impl_comparison(unifier, store, t, t); @@ -459,11 +468,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie } /* float ======== */ - impl_basic_arithmetic(unifier, store, float_t, &[float_t], Some(float_t)); - impl_pow(unifier, store, float_t, &[int32_t, float_t], Some(float_t)); - impl_div(unifier, store, float_t, &[float_t]); - impl_floordiv(unifier, store, float_t, &[float_t], Some(float_t)); - impl_mod(unifier, store, float_t, &[float_t], Some(float_t)); + let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None); + let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None); + impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None); + impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None); + impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None); + impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None); + impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_sign(unifier, store, float_t); impl_not(unifier, store, float_t); impl_comparison(unifier, store, float_t, float_t); @@ -472,4 +483,15 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* bool ======== */ impl_not(unifier, store, bool_t); impl_eq(unifier, store, bool_t); + + /* ndarray ===== */ + let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); + let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0)); + let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); + let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); + impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); + impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 5fb89e0..523f270 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1203,8 +1203,11 @@ impl<'a> Inferencer<'a> { right: &ast::Expr>, is_aug_assign: bool, ) -> InferenceResult { + let left_ty = left.custom.unwrap(); + let right_ty = right.custom.unwrap(); + let method = if let TypeEnum::TObj { fields, .. } = - self.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() + self.unifier.get_ty_immutable(left_ty).as_ref() { let (binop_name, binop_assign_name) = ( binop_name(op).into(), @@ -1219,12 +1222,26 @@ impl<'a> Inferencer<'a> { } else { binop_name(op).into() }; + + let ret = if is_aug_assign { + // The type of an augmented assignment operator should never change + Some(left_ty) + } else { + typeof_binop( + self.unifier, + self.primitives, + op, + left_ty, + right_ty, + ).map_err(|e| HashSet::from([format!("{} (at {})", e, location)]))? + }; + self.build_method_call( location, method, - left.custom.unwrap(), - vec![right.custom.unwrap()], - None, + left_ty, + vec![right_ty], + ret, ) } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index c23b3d8..c36893e 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -765,12 +765,8 @@ impl Unifier { // If the types don't match, try to implicitly promote integers if !self.unioned(ty, value_ty) { - let num_val = match *value { - SymbolValue::I32(v) => v as i128, - SymbolValue::I64(v) => v as i128, - SymbolValue::U32(v) => v as i128, - SymbolValue::U64(v) => v as i128, - _ => return Self::incompatible_types(a, b), + let Ok(num_val) = i128::try_from(value.clone()) else { + return Self::incompatible_types(a, b) }; let can_convert = if self.unioned(ty, primitives.int32) { diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 353ebe5..87f0b36 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -6,6 +6,19 @@ def output_int32(x: int32): def output_float64(x: float): ... +def output_ndarray_int32_1(n: ndarray[int32, Literal[1]]): + for i in range(len(n)): + output_int32(n[i]) + +def output_ndarray_float_1(n: ndarray[float, Literal[1]]): + for i in range(len(n)): + output_float64(n[i]) + +def output_ndarray_float_2(n: ndarray[float, Literal[2]]): + for r in range(len(n)): + for c in range(len(n[r])): + output_float64(n[r][c]) + def consume_ndarray_1(n: ndarray[float, Literal[1]]): pass @@ -19,53 +32,381 @@ def test_ndarray_empty(): def test_ndarray_zeros(): n: ndarray[float, 1] = np_zeros([1]) - output_float64(n[0]) + output_ndarray_float_1(n) def test_ndarray_ones(): n: ndarray[float, 1] = np_ones([1]) - output_float64(n[0]) + output_ndarray_float_1(n) def test_ndarray_full(): n_float: ndarray[float, 1] = np_full([1], 2.0) - output_float64(n_float[0]) + output_ndarray_float_1(n_float) n_i32: ndarray[int32, 1] = np_full([1], 2) - output_int32(n_i32[0]) + output_ndarray_int32_1(n_i32) def test_ndarray_eye(): n: ndarray[float, 2] = np_eye(2) - n0: ndarray[float, 1] = n[0] - v: float = n0[0] - output_float64(v) + output_ndarray_float_2(n) def test_ndarray_identity(): n: ndarray[float, 2] = np_identity(2) - output_float64(n[0][0]) - output_float64(n[0][1]) - output_float64(n[1][0]) - output_float64(n[1][1]) + output_ndarray_float_2(n) def test_ndarray_fill(): n: ndarray[float, 2] = np_empty([2, 2]) n.fill(1.0) - output_float64(n[0][0]) - output_float64(n[0][1]) - output_float64(n[1][0]) - output_float64(n[1][1]) + output_ndarray_float_2(n) def test_ndarray_copy(): x: ndarray[float, 2] = np_identity(2) y = x.copy() x.fill(0.0) - output_float64(x[0][0]) - output_float64(x[0][1]) - output_float64(x[1][0]) - output_float64(x[1][1]) + output_ndarray_float_2(x) + output_ndarray_float_2(y) - output_float64(y[0][0]) - output_float64(y[0][1]) - output_float64(y[1][0]) - output_float64(y[1][1]) +def test_ndarray_add(): + x = np_identity(2) + y = x + np_ones([2, 2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_add_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x + np_ones([2]) + y = x + np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_add_broadcast_lhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = 1.0 + x + y = 1.0 + x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_add_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x + 1.0 + y = x + 1.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_iadd(): + x = np_identity(2) + x += np_ones([2, 2]) + + output_ndarray_float_2(x) + +def test_ndarray_iadd_broadcast(): + x = np_identity(2) + x += np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_iadd_broadcast_scalar(): + x = np_identity(2) + x += 1.0 + + output_ndarray_float_2(x) + +def test_ndarray_sub(): + x = np_ones([2, 2]) + y = x - np_identity(2) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_sub_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x - np_ones([2]) + y = x - np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_sub_broadcast_lhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = 1.0 - x + y = 1.0 - x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_sub_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x - 1 + y = x - 1.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_isub(): + x = np_ones([2, 2]) + x -= np_identity(2) + + output_ndarray_float_2(x) + +def test_ndarray_isub_broadcast(): + x = np_identity(2) + x -= np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_isub_broadcast_scalar(): + x = np_identity(2) + x -= 1.0 + + output_ndarray_float_2(x) + +def test_ndarray_mul(): + x = np_ones([2, 2]) + y = x * np_identity(2) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mul_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x * np_ones([2]) + y = x * np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mul_broadcast_lhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = 2.0 * x + y = 2.0 * x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mul_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x * 2.0 + y = x * 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_imul(): + x = np_ones([2, 2]) + x *= np_identity(2) + + output_ndarray_float_2(x) + +def test_ndarray_imul_broadcast(): + x = np_identity(2) + x *= np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_imul_broadcast_scalar(): + x = np_identity(2) + x *= 2.0 + + output_ndarray_float_2(x) + +def test_ndarray_truediv(): + x = np_identity(2) + y = x / np_ones([2, 2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_truediv_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x / np_ones([2]) + y = x / np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_truediv_broadcast_lhs_scalar(): + x = np_ones([2, 2]) + # y: ndarray[float, 2] = 2.0 / x + y = 2.0 / x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_truediv_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x / 2.0 + y = x / 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_itruediv(): + x = np_identity(2) + x /= np_ones([2, 2]) + + output_ndarray_float_2(x) + +def test_ndarray_itruediv_broadcast(): + x = np_identity(2) + x /= np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_itruediv_broadcast_scalar(): + x = np_identity(2) + x /= 2.0 + + output_ndarray_float_2(x) + +def test_ndarray_floordiv(): + x = np_identity(2) + y = x // np_ones([2, 2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_floordiv_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x // np_ones([2]) + y = x // np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_floordiv_broadcast_lhs_scalar(): + x = np_ones([2, 2]) + # y: ndarray[float, 2] = 2.0 // x + y = 2.0 // x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_floordiv_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x // 2.0 + y = x // 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_ifloordiv(): + x = np_identity(2) + x //= np_ones([2, 2]) + + output_ndarray_float_2(x) + +def test_ndarray_ifloordiv_broadcast(): + x = np_identity(2) + x //= np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_ifloordiv_broadcast_scalar(): + x = np_identity(2) + x //= 2.0 + + output_ndarray_float_2(x) + +def test_ndarray_mod(): + x = np_identity(2) + y = x % np_full([2, 2], 2.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mod_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x % np_ones([2]) + y = x % np_full([2], 2.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mod_broadcast_lhs_scalar(): + x = np_ones([2, 2]) + # y: ndarray[float, 2] = 2.0 % x + y = 2.0 % x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mod_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x % 2.0 + y = x % 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_imod(): + x = np_identity(2) + x %= np_full([2, 2], 2.0) + + output_ndarray_float_2(x) + +def test_ndarray_imod_broadcast(): + x = np_identity(2) + x %= np_full([2], 2.0) + + output_ndarray_float_2(x) + +def test_ndarray_imod_broadcast_scalar(): + x = np_identity(2) + x %= 2.0 + + output_ndarray_float_2(x) + +def test_ndarray_pow(): + x = np_identity(2) + y = x ** np_full([2, 2], 2.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_pow_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x ** np_full([2], 2.0) + y = x ** np_full([2], 2.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_pow_broadcast_lhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = 2.0 ** x + y = 2.0 ** x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_pow_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x % 2.0 + y = x ** 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_ipow(): + x = np_identity(2) + x **= np_full([2, 2], 2.0) + + output_ndarray_float_2(x) + +def test_ndarray_ipow_broadcast(): + x = np_identity(2) + x **= np_full([2], 2.0) + + output_ndarray_float_2(x) + +def test_ndarray_ipow_broadcast_scalar(): + x = np_identity(2) + x **= 2.0 + + output_ndarray_float_2(x) def run() -> int32: test_ndarray_ctor() @@ -77,5 +418,54 @@ def run() -> int32: test_ndarray_identity() test_ndarray_fill() test_ndarray_copy() + test_ndarray_add() + test_ndarray_add_broadcast() + test_ndarray_add_broadcast_lhs_scalar() + test_ndarray_add_broadcast_rhs_scalar() + test_ndarray_iadd() + test_ndarray_iadd_broadcast() + test_ndarray_iadd_broadcast_scalar() + test_ndarray_sub() + test_ndarray_sub_broadcast() + test_ndarray_sub_broadcast_lhs_scalar() + test_ndarray_sub_broadcast_rhs_scalar() + test_ndarray_isub() + test_ndarray_isub_broadcast() + test_ndarray_isub_broadcast_scalar() + test_ndarray_mul() + test_ndarray_mul_broadcast() + test_ndarray_mul_broadcast_lhs_scalar() + test_ndarray_mul_broadcast_rhs_scalar() + test_ndarray_imul() + test_ndarray_imul_broadcast() + test_ndarray_imul_broadcast_scalar() + test_ndarray_truediv() + test_ndarray_truediv_broadcast() + test_ndarray_truediv_broadcast_lhs_scalar() + test_ndarray_truediv_broadcast_rhs_scalar() + test_ndarray_itruediv() + test_ndarray_itruediv_broadcast() + test_ndarray_itruediv_broadcast_scalar() + test_ndarray_floordiv() + test_ndarray_floordiv_broadcast() + test_ndarray_floordiv_broadcast_lhs_scalar() + test_ndarray_floordiv_broadcast_rhs_scalar() + test_ndarray_ifloordiv() + test_ndarray_ifloordiv_broadcast() + test_ndarray_ifloordiv_broadcast_scalar() + test_ndarray_mod() + test_ndarray_mod_broadcast() + test_ndarray_mod_broadcast_lhs_scalar() + test_ndarray_mod_broadcast_rhs_scalar() + test_ndarray_imod() + test_ndarray_imod_broadcast() + test_ndarray_imod_broadcast_scalar() + test_ndarray_pow() + test_ndarray_pow_broadcast() + test_ndarray_pow_broadcast_lhs_scalar() + test_ndarray_pow_broadcast_rhs_scalar() + test_ndarray_ipow() + test_ndarray_ipow_broadcast() + test_ndarray_ipow_broadcast_scalar() return 0