From 6af13a8261510c44e7f55aec4b031cb9ea4ce424 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 13 Mar 2024 11:16:23 +0800 Subject: [PATCH] core: Implement elementwise binary operators Including immediate variants of these operators. --- nac3core/src/codegen/classes.rs | 2 +- nac3core/src/codegen/expr.rs | 75 ++- nac3core/src/codegen/numpy.rs | 195 ++++++++ nac3core/src/codegen/stmt.rs | 2 +- nac3core/src/toplevel/builtins.rs | 245 ++++++++++ nac3core/src/toplevel/helper.rs | 56 +++ ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/magic_methods.rs | 40 +- nac3core/src/typecheck/type_inferencer/mod.rs | 25 +- .../src/typecheck/type_inferencer/test.rs | 7 +- nac3core/src/typecheck/typedef/mod.rs | 8 +- nac3standalone/demo/src/ndarray.py | 436 +++++++++++++++++- 16 files changed, 1049 insertions(+), 56 deletions(-) diff --git a/nac3core/src/codegen/classes.rs b/nac3core/src/codegen/classes.rs index 84d86b2b9..3df2d63d0 100644 --- a/nac3core/src/codegen/classes.rs +++ b/nac3core/src/codegen/classes.rs @@ -11,7 +11,7 @@ use crate::codegen::{ stmt::gen_for_callback_incrementing, }; -/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of +/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of /// elements. pub trait ArrayLikeValue<'ctx> { /// Returns the element type of this array-like value. diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index e96878f74..3f8bafe77 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -17,6 +17,7 @@ use crate::{ get_llvm_abi_type, irrt::*, llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi}, + numpy, stmt::{gen_raise, gen_var}, CodeGenContext, CodeGenTask, }, @@ -24,7 +25,7 @@ use crate::{ toplevel::{ DefinitionId, helper::PRIMITIVE_DEF_IDS, - numpy::make_ndarray_ty, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelDef, }, typecheck::{ @@ -1129,6 +1130,78 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Some("f_pow_i") ); Ok(Some(res.into())) + } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) { + let llvm_usize = generator.get_size_type(ctx.ctx); + + let is_ndarray1 = ty1.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + let is_ndarray2 = ty2.obj_id(&ctx.unifier) + .is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray); + + if is_ndarray1 && is_ndarray2 { + let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); + let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2); + + assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2)); + + let left_val = NDArrayValue::from_ptr_val( + left_val.into_pointer_value(), + llvm_usize, + None + ); + let res = numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + ndarray_dtype1, + if is_aug_assign { Some(left_val) } else { None }, + (left_val.as_ptr_value().into(), false), + (right_val, false), + |generator, ctx, (lhs, rhs)| { + gen_binop_expr_with_values( + generator, + ctx, + (&Some(ndarray_dtype1), lhs), + op, + (&Some(ndarray_dtype2), rhs), + ctx.current_loc, + is_aug_assign, + )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1) + }, + )?; + + Ok(Some(res.as_ptr_value().into())) + } else { + let (ndarray_dtype, _) = unpack_ndarray_var_tys( + &mut ctx.unifier, + if is_ndarray1 { ty1 } else { ty2 }, + ); + let ndarray_val = NDArrayValue::from_ptr_val( + if is_ndarray1 { left_val } else { right_val }.into_pointer_value(), + llvm_usize, + None, + ); + let res = numpy::ndarray_elementwise_binop_impl( + generator, + ctx, + ndarray_dtype, + if is_aug_assign { Some(ndarray_val) } else { None }, + (left_val, !is_ndarray1), + (right_val, !is_ndarray2), + |generator, ctx, (lhs, rhs)| { + gen_binop_expr_with_values( + generator, + ctx, + (&Some(ndarray_dtype), lhs), + op, + (&Some(ndarray_dtype), rhs), + ctx.current_loc, + is_aug_assign, + )?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype) + }, + )?; + + Ok(Some(res.as_ptr_value().into())) + } } else { let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else { diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index bb05ef905..454f238e8 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -18,6 +18,8 @@ use crate::{ CodeGenContext, CodeGenerator, irrt::{ + call_ndarray_calc_broadcast, + call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices, call_ndarray_calc_size, }, @@ -338,6 +340,98 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>( ) } +/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of +/// the target `ndarray`. +fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + target: NDArrayValue<'ctx>, + source: NDArrayValue<'ctx>, +) { + let array_ndims = source.load_ndims(ctx); + let broadcast_size = target.load_ndims(ctx); + + ctx.make_assert( + generator, + ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(), + "0:ValueError", + "operands cannot be broadcast together", + [None, None, None], + ctx.current_loc, + ); +} + +/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value +/// with broadcast-compatible shapes. +fn ndarray_broadcast_fill<'ctx, G, ValueFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + res: NDArrayValue<'ctx>, + lhs: (BasicValueEnum<'ctx>, bool), + rhs: (BasicValueEnum<'ctx>, bool), + value_fn: ValueFn, +) -> Result, String> + where + G: CodeGenerator + ?Sized, + ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result, String>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (lhs_val, lhs_scalar) = lhs; + let (rhs_val, rhs_scalar) = rhs; + + assert!(!(lhs_scalar && rhs_scalar), + "One of the operands must be a ndarray instance: `{}`, `{}`", + lhs_val.get_type(), + rhs_val.get_type()); + + // Assert that all ndarray operands are broadcastable to the target size + if !lhs_scalar { + let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val); + } + + if !rhs_scalar { + let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val); + } + + ndarray_fill_indexed( + generator, + ctx, + res, + |generator, ctx, idx| { + let lhs_elem = if lhs_scalar { + lhs_val + } else { + let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, &idx); + + unsafe { + lhs.data().get_unchecked(ctx, generator, lhs_idx, None) + } + }; + + let rhs_elem = if rhs_scalar { + rhs_val + } else { + let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, &idx); + + unsafe { + rhs.data().get_unchecked(ctx, generator, rhs_idx, None) + } + }; + + debug_assert_eq!(lhs_elem.get_type(), rhs_elem.get_type()); + + value_fn(generator, ctx, (lhs_elem, rhs_elem)) + }, + )?; + + Ok(res) +} + /// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. /// /// * `elem_ty` - The element type of the `NDArray`. @@ -562,6 +656,107 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>( Ok(ndarray) } +/// LLVM-typed implementation for computing elementwise binary operations on two input operands. +/// +/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output +/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple. +/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the +/// `value_fn` arguments tuple for all output elements. +/// +/// The second element of the tuple indicates whether to treat the operand value as a `ndarray` +/// (which would be accessed by its broadcast index) or as a scalar value (which would be +/// broadcast to all elements). +/// +/// * `elem_ty` - The element type of the `NDArray`. +/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be +/// written to a new `ndarray`. +/// * `value_fn` - Function mapping the two input elements into the result. +/// +/// # Panic +/// +/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`. +pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>( + generator: &mut G, + ctx: &mut CodeGenContext<'ctx, '_>, + elem_ty: Type, + res: Option>, + lhs: (BasicValueEnum<'ctx>, bool), + rhs: (BasicValueEnum<'ctx>, bool), + value_fn: ValueFn, +) -> Result, String> + where + G: CodeGenerator, + ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result, String>, +{ + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (lhs_val, lhs_scalar) = lhs; + let (rhs_val, rhs_scalar) = rhs; + + assert!(!(lhs_scalar && rhs_scalar), + "One of the operands must be a ndarray instance: `{}`, `{}`", + lhs_val.get_type(), + rhs_val.get_type()); + + let ndarray = res.unwrap_or_else(|| { + if lhs_scalar && rhs_scalar { + let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None); + let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None); + + let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val); + + create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &ndarray_dims, + |generator, ctx, v| { + Ok(v.size(ctx, generator)) + }, + |generator, ctx, v, idx| { + unsafe { + Ok(v.get_typed_unchecked(ctx, generator, idx, None)) + } + }, + ).unwrap() + } else { + let ndarray = NDArrayValue::from_ptr_val( + if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(), + llvm_usize, + None, + ); + + create_ndarray_dyn_shape( + generator, + ctx, + elem_ty, + &ndarray, + |_, ctx, v| { + Ok(v.load_ndims(ctx)) + }, + |generator, ctx, v, idx| { + unsafe { + Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None)) + } + }, + ).unwrap() + } + }); + + ndarray_broadcast_fill( + generator, + ctx, + ndarray, + lhs, + rhs, + |generator, ctx, elems| { + value_fn(generator, ctx, elems) + }, + )?; + + Ok(ndarray) +} + /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx>( context: &mut CodeGenContext<'ctx, '_>, diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 1e0ab2cd1..fd5d51a05 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -546,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>( /// body(x); /// } /// ``` -/// +/// /// * `init_val` - The initial value of the loop variable. The type of this value will also be used /// as the type of the loop variable. /// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 4064db66d..c15810ead 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -299,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Some("N".into()), None, ); + let size_t = primitives.0.usize(); + let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let exception_fields = vec![ ("__name__".into(), int32, true), @@ -345,8 +347,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .nth(1) .map(|(var_id, ty)| (*ty, *var_id)) .unwrap(); + let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var( + size_t, + Some("ndarray_ndims".into()), + None, + ); let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); + let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap(); + let ndarray_sub_ty = *ndarray_fields.get(&"__sub__".into()).unwrap(); + let ndarray_mul_ty = *ndarray_fields.get(&"__mul__".into()).unwrap(); + let ndarray_truediv_ty = *ndarray_fields.get(&"__truediv__".into()).unwrap(); + let ndarray_floordiv_ty = *ndarray_fields.get(&"__floordiv__".into()).unwrap(); + let ndarray_mod_ty = *ndarray_fields.get(&"__mod__".into()).unwrap(); + let ndarray_pow_ty = *ndarray_fields.get(&"__pow__".into()).unwrap(); + let ndarray_iadd_ty = *ndarray_fields.get(&"__iadd__".into()).unwrap(); + let ndarray_isub_ty = *ndarray_fields.get(&"__isub__".into()).unwrap(); + let ndarray_imul_ty = *ndarray_fields.get(&"__imul__".into()).unwrap(); + let ndarray_itruediv_ty = *ndarray_fields.get(&"__itruediv__".into()).unwrap(); + let ndarray_ifloordiv_ty = *ndarray_fields.get(&"__ifloordiv__".into()).unwrap(); + let ndarray_imod_ty = *ndarray_fields.get(&"__imod__".into()).unwrap(); + let ndarray_ipow_ty = *ndarray_fields.get(&"__ipow__".into()).unwrap(); let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( @@ -524,6 +545,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { methods: vec![ ("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)), ("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)), + ("__add__".into(), ndarray_add_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 3)), + ("__sub__".into(), ndarray_sub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 4)), + ("__mul__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 5)), + ("__truediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 6)), + ("__floordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 7)), + ("__mod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 8)), + ("__pow__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 9)), + ("__iadd__".into(), ndarray_iadd_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 10)), + ("__isub__".into(), ndarray_isub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 11)), + ("__imul__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 12)), + ("__itruediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 13)), + ("__ifloordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 14)), + ("__imod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 15)), + ("__ipow__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 16)), ], ancestors: Vec::default(), constructor: None, @@ -562,6 +597,216 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__add__".into(), + simple_name: "__add__".into(), + signature: ndarray_add_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__sub__".into(), + simple_name: "__sub__".into(), + signature: ndarray_sub_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__mul__".into(), + simple_name: "__mul__".into(), + signature: ndarray_mul_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__truediv__".into(), + simple_name: "__truediv__".into(), + signature: ndarray_truediv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__floordiv__".into(), + simple_name: "__floordiv__".into(), + signature: ndarray_floordiv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__mod__".into(), + simple_name: "__mod__".into(), + signature: ndarray_mod_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__pow__".into(), + simple_name: "__pow__".into(), + signature: ndarray_pow_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__iadd__".into(), + simple_name: "__iadd__".into(), + signature: ndarray_iadd_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__isub__".into(), + simple_name: "__isub__".into(), + signature: ndarray_isub_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__imul__".into(), + simple_name: "__imul__".into(), + signature: ndarray_imul_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__itruediv__".into(), + simple_name: "__itruediv__".into(), + signature: ndarray_itruediv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__ifloordiv__".into(), + simple_name: "__ifloordiv__".into(), + signature: ndarray_ifloordiv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__imod__".into(), + simple_name: "__imod__".into(), + signature: ndarray_imod_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__ipow__".into(), + simple_name: "__ipow__".into(), + signature: ndarray_ipow_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index e212ac8b5..33f6d48d8 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,6 +1,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; +use crate::toplevel::numpy::subst_ndarray_tvars; use crate::typecheck::typedef::{Mapping, VarMap}; use nac3parser::ast::{Constant, Location}; @@ -231,11 +232,57 @@ impl TopLevelComposer { (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), ]), })); + let ndarray_binop_fun_other_ty = unifier.get_fresh_var(None, None); + let ndarray_binop_fun_ret_ty = unifier.get_fresh_var(None, None); + let ndarray_binop_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "other".into(), + ty: ndarray_binop_fun_other_ty.0, + default_value: None, + }, + ], + ret: ndarray_binop_fun_ret_ty.0, + vars: VarMap::from([ + (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), + (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), + ]), + })); + let ndarray_truediv_fun_other_ty = unifier.get_fresh_var(None, None); + let ndarray_truediv_fun_ret_ty = unifier.get_fresh_var(None, None); + let ndarray_truediv_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "other".into(), + ty: ndarray_truediv_fun_other_ty.0, + default_value: None, + }, + ], + ret: ndarray_truediv_fun_ret_ty.0, + vars: VarMap::from([ + (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), + (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), + ]), + })); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PRIMITIVE_DEF_IDS.ndarray, fields: Mapping::from([ ("copy".into(), (ndarray_copy_fun_ty, true)), ("fill".into(), (ndarray_fill_fun_ty, true)), + ("__add__".into(), (ndarray_binop_fun_ty, true)), + ("__sub__".into(), (ndarray_binop_fun_ty, true)), + ("__mul__".into(), (ndarray_binop_fun_ty, true)), + ("__truediv__".into(), (ndarray_truediv_fun_ty, true)), + ("__floordiv__".into(), (ndarray_binop_fun_ty, true)), + ("__mod__".into(), (ndarray_binop_fun_ty, true)), + ("__pow__".into(), (ndarray_binop_fun_ty, true)), + ("__iadd__".into(), (ndarray_binop_fun_ty, true)), + ("__isub__".into(), (ndarray_binop_fun_ty, true)), + ("__imul__".into(), (ndarray_binop_fun_ty, true)), + ("__itruediv__".into(), (ndarray_truediv_fun_ty, true)), + ("__ifloordiv__".into(), (ndarray_binop_fun_ty, true)), + ("__imod__".into(), (ndarray_binop_fun_ty, true)), + ("__ipow__".into(), (ndarray_binop_fun_ty, true)), ]), params: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), @@ -243,7 +290,16 @@ impl TopLevelComposer { ]), }); + let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None); + let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None); + unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap(); + unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap(); + unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap(); + + let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None); + unifier.unify(ndarray_truediv_fun_other_ty.0, ndarray).unwrap(); + unifier.unify(ndarray_truediv_fun_ret_ty.0, ndarray_float).unwrap(); let primitives = PrimitiveStore { int32, diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 6efad66f0..498e3f157 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [30]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [124]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 5e69d6648..8454bfbbb 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar19]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar19\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar113]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar113\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 7d2479351..ee506c167 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [126]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [131]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index e32ae8098..16159e46a 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar112, typevar113]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar112\", \"typevar113\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 59a27702d..dfda3a845 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [38]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [132]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [46]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [140]\n}\n", ] diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index ec0c064cc..bfd137d02 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -453,8 +453,8 @@ pub fn typeof_binop( } Operator::LShift - | Operator::RShift - | Operator::BitOr + | Operator::RShift => lhs, + Operator::BitOr | Operator::BitXor | Operator::BitAnd => { if unifier.unioned(lhs, rhs) { @@ -474,18 +474,21 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie bool: bool_t, uint32: uint32_t, uint64: uint64_t, + ndarray: ndarray_t, .. } = *store; + let size_t = store.usize(); /* int ======== */ for t in [int32_t, int64_t, uint32_t, uint64_t] { - impl_basic_arithmetic(unifier, store, t, &[t], Some(t)); - impl_pow(unifier, store, t, &[t], Some(t)); + let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None); + impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None); + impl_pow(unifier, store, t, &[t, ndarray_int_t], None); impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_shift(unifier, store, t); - impl_div(unifier, store, t, &[t], Some(float_t)); - impl_floordiv(unifier, store, t, &[t], Some(t)); - impl_mod(unifier, store, t, &[t], Some(t)); + impl_div(unifier, store, t, &[t, ndarray_int_t], None); + impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None); + impl_mod(unifier, store, t, &[t, ndarray_int_t], None); impl_invert(unifier, store, t, Some(t)); impl_not(unifier, store, t, Some(bool_t)); impl_comparison(unifier, store, t, &[t], Some(bool_t)); @@ -496,11 +499,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie } /* float ======== */ - impl_basic_arithmetic(unifier, store, float_t, &[float_t], Some(float_t)); - impl_pow(unifier, store, float_t, &[int32_t, float_t], Some(float_t)); - impl_div(unifier, store, float_t, &[float_t], Some(float_t)); - impl_floordiv(unifier, store, float_t, &[float_t], Some(float_t)); - impl_mod(unifier, store, float_t, &[float_t], Some(float_t)); + let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None); + let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None); + impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None); + impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None); + impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None); + impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None); + impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None); impl_sign(unifier, store, float_t, Some(float_t)); impl_not(unifier, store, float_t, Some(bool_t)); impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t)); @@ -509,4 +514,15 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* bool ======== */ impl_not(unifier, store, bool_t, Some(bool_t)); impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t)); + + /* ndarray ===== */ + let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); + let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0)); + let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t); + let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t); + impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None); + impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); + impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None); } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 5fb89e0b2..1b28a2451 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -1203,8 +1203,11 @@ impl<'a> Inferencer<'a> { right: &ast::Expr>, is_aug_assign: bool, ) -> InferenceResult { + let left_ty = left.custom.unwrap(); + let right_ty = right.custom.unwrap(); + let method = if let TypeEnum::TObj { fields, .. } = - self.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() + self.unifier.get_ty_immutable(left_ty).as_ref() { let (binop_name, binop_assign_name) = ( binop_name(op).into(), @@ -1219,12 +1222,26 @@ impl<'a> Inferencer<'a> { } else { binop_name(op).into() }; + + let ret = if is_aug_assign { + // The type of augmented assignment operator should never change + Some(left_ty) + } else { + typeof_binop( + self.unifier, + self.primitives, + op, + left_ty, + right_ty, + ).map_err(|e| HashSet::from([format!("{e} (at {location})")]))? + }; + self.build_method_call( location, method, - left.custom.unwrap(), - vec![right.custom.unwrap()], - None, + left_ty, + vec![right_ty], + ret, ) } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 8c72cfb13..e6fe80091 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -135,10 +135,15 @@ impl TestEnvironment { fields: HashMap::new(), params: VarMap::new(), }); + let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None); + let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PRIMITIVE_DEF_IDS.ndarray, fields: HashMap::new(), - params: VarMap::new(), + params: VarMap::from([ + (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), + (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), + ]), }); let primitives = PrimitiveStore { int32, diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 3ad2b3983..4d6098a3e 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -774,12 +774,8 @@ impl Unifier { // If the types don't match, try to implicitly promote integers if !self.unioned(ty, value_ty) { - let num_val = match *value { - SymbolValue::I32(v) => v as i128, - SymbolValue::I64(v) => v as i128, - SymbolValue::U32(v) => v as i128, - SymbolValue::U64(v) => v as i128, - _ => return Self::incompatible_types(a, b), + let Ok(num_val) = i128::try_from(value.clone()) else { + return Self::incompatible_types(a, b) }; let can_convert = if self.unioned(ty, primitives.int32) { diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 353ebe5c6..87f0b3615 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -6,6 +6,19 @@ def output_int32(x: int32): def output_float64(x: float): ... +def output_ndarray_int32_1(n: ndarray[int32, Literal[1]]): + for i in range(len(n)): + output_int32(n[i]) + +def output_ndarray_float_1(n: ndarray[float, Literal[1]]): + for i in range(len(n)): + output_float64(n[i]) + +def output_ndarray_float_2(n: ndarray[float, Literal[2]]): + for r in range(len(n)): + for c in range(len(n[r])): + output_float64(n[r][c]) + def consume_ndarray_1(n: ndarray[float, Literal[1]]): pass @@ -19,53 +32,381 @@ def test_ndarray_empty(): def test_ndarray_zeros(): n: ndarray[float, 1] = np_zeros([1]) - output_float64(n[0]) + output_ndarray_float_1(n) def test_ndarray_ones(): n: ndarray[float, 1] = np_ones([1]) - output_float64(n[0]) + output_ndarray_float_1(n) def test_ndarray_full(): n_float: ndarray[float, 1] = np_full([1], 2.0) - output_float64(n_float[0]) + output_ndarray_float_1(n_float) n_i32: ndarray[int32, 1] = np_full([1], 2) - output_int32(n_i32[0]) + output_ndarray_int32_1(n_i32) def test_ndarray_eye(): n: ndarray[float, 2] = np_eye(2) - n0: ndarray[float, 1] = n[0] - v: float = n0[0] - output_float64(v) + output_ndarray_float_2(n) def test_ndarray_identity(): n: ndarray[float, 2] = np_identity(2) - output_float64(n[0][0]) - output_float64(n[0][1]) - output_float64(n[1][0]) - output_float64(n[1][1]) + output_ndarray_float_2(n) def test_ndarray_fill(): n: ndarray[float, 2] = np_empty([2, 2]) n.fill(1.0) - output_float64(n[0][0]) - output_float64(n[0][1]) - output_float64(n[1][0]) - output_float64(n[1][1]) + output_ndarray_float_2(n) def test_ndarray_copy(): x: ndarray[float, 2] = np_identity(2) y = x.copy() x.fill(0.0) - output_float64(x[0][0]) - output_float64(x[0][1]) - output_float64(x[1][0]) - output_float64(x[1][1]) + output_ndarray_float_2(x) + output_ndarray_float_2(y) - output_float64(y[0][0]) - output_float64(y[0][1]) - output_float64(y[1][0]) - output_float64(y[1][1]) +def test_ndarray_add(): + x = np_identity(2) + y = x + np_ones([2, 2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_add_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x + np_ones([2]) + y = x + np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_add_broadcast_lhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = 1.0 + x + y = 1.0 + x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_add_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x + 1.0 + y = x + 1.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_iadd(): + x = np_identity(2) + x += np_ones([2, 2]) + + output_ndarray_float_2(x) + +def test_ndarray_iadd_broadcast(): + x = np_identity(2) + x += np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_iadd_broadcast_scalar(): + x = np_identity(2) + x += 1.0 + + output_ndarray_float_2(x) + +def test_ndarray_sub(): + x = np_ones([2, 2]) + y = x - np_identity(2) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_sub_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x - np_ones([2]) + y = x - np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_sub_broadcast_lhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = 1.0 - x + y = 1.0 - x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_sub_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x - 1 + y = x - 1.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_isub(): + x = np_ones([2, 2]) + x -= np_identity(2) + + output_ndarray_float_2(x) + +def test_ndarray_isub_broadcast(): + x = np_identity(2) + x -= np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_isub_broadcast_scalar(): + x = np_identity(2) + x -= 1.0 + + output_ndarray_float_2(x) + +def test_ndarray_mul(): + x = np_ones([2, 2]) + y = x * np_identity(2) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mul_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x * np_ones([2]) + y = x * np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mul_broadcast_lhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = 2.0 * x + y = 2.0 * x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mul_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x * 2.0 + y = x * 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_imul(): + x = np_ones([2, 2]) + x *= np_identity(2) + + output_ndarray_float_2(x) + +def test_ndarray_imul_broadcast(): + x = np_identity(2) + x *= np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_imul_broadcast_scalar(): + x = np_identity(2) + x *= 2.0 + + output_ndarray_float_2(x) + +def test_ndarray_truediv(): + x = np_identity(2) + y = x / np_ones([2, 2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_truediv_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x / np_ones([2]) + y = x / np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_truediv_broadcast_lhs_scalar(): + x = np_ones([2, 2]) + # y: ndarray[float, 2] = 2.0 / x + y = 2.0 / x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_truediv_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x / 2.0 + y = x / 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_itruediv(): + x = np_identity(2) + x /= np_ones([2, 2]) + + output_ndarray_float_2(x) + +def test_ndarray_itruediv_broadcast(): + x = np_identity(2) + x /= np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_itruediv_broadcast_scalar(): + x = np_identity(2) + x /= 2.0 + + output_ndarray_float_2(x) + +def test_ndarray_floordiv(): + x = np_identity(2) + y = x // np_ones([2, 2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_floordiv_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x // np_ones([2]) + y = x // np_ones([2]) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_floordiv_broadcast_lhs_scalar(): + x = np_ones([2, 2]) + # y: ndarray[float, 2] = 2.0 // x + y = 2.0 // x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_floordiv_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x // 2.0 + y = x // 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_ifloordiv(): + x = np_identity(2) + x //= np_ones([2, 2]) + + output_ndarray_float_2(x) + +def test_ndarray_ifloordiv_broadcast(): + x = np_identity(2) + x //= np_ones([2]) + + output_ndarray_float_2(x) + +def test_ndarray_ifloordiv_broadcast_scalar(): + x = np_identity(2) + x //= 2.0 + + output_ndarray_float_2(x) + +def test_ndarray_mod(): + x = np_identity(2) + y = x % np_full([2, 2], 2.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mod_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x % np_ones([2]) + y = x % np_full([2], 2.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mod_broadcast_lhs_scalar(): + x = np_ones([2, 2]) + # y: ndarray[float, 2] = 2.0 % x + y = 2.0 % x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_mod_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x % 2.0 + y = x % 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_imod(): + x = np_identity(2) + x %= np_full([2, 2], 2.0) + + output_ndarray_float_2(x) + +def test_ndarray_imod_broadcast(): + x = np_identity(2) + x %= np_full([2], 2.0) + + output_ndarray_float_2(x) + +def test_ndarray_imod_broadcast_scalar(): + x = np_identity(2) + x %= 2.0 + + output_ndarray_float_2(x) + +def test_ndarray_pow(): + x = np_identity(2) + y = x ** np_full([2, 2], 2.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_pow_broadcast(): + x = np_identity(2) + # y: ndarray[float, 2] = x ** np_full([2], 2.0) + y = x ** np_full([2], 2.0) + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_pow_broadcast_lhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = 2.0 ** x + y = 2.0 ** x + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_pow_broadcast_rhs_scalar(): + x = np_identity(2) + # y: ndarray[float, 2] = x % 2.0 + y = x ** 2.0 + + output_ndarray_float_2(x) + output_ndarray_float_2(y) + +def test_ndarray_ipow(): + x = np_identity(2) + x **= np_full([2, 2], 2.0) + + output_ndarray_float_2(x) + +def test_ndarray_ipow_broadcast(): + x = np_identity(2) + x **= np_full([2], 2.0) + + output_ndarray_float_2(x) + +def test_ndarray_ipow_broadcast_scalar(): + x = np_identity(2) + x **= 2.0 + + output_ndarray_float_2(x) def run() -> int32: test_ndarray_ctor() @@ -77,5 +418,54 @@ def run() -> int32: test_ndarray_identity() test_ndarray_fill() test_ndarray_copy() + test_ndarray_add() + test_ndarray_add_broadcast() + test_ndarray_add_broadcast_lhs_scalar() + test_ndarray_add_broadcast_rhs_scalar() + test_ndarray_iadd() + test_ndarray_iadd_broadcast() + test_ndarray_iadd_broadcast_scalar() + test_ndarray_sub() + test_ndarray_sub_broadcast() + test_ndarray_sub_broadcast_lhs_scalar() + test_ndarray_sub_broadcast_rhs_scalar() + test_ndarray_isub() + test_ndarray_isub_broadcast() + test_ndarray_isub_broadcast_scalar() + test_ndarray_mul() + test_ndarray_mul_broadcast() + test_ndarray_mul_broadcast_lhs_scalar() + test_ndarray_mul_broadcast_rhs_scalar() + test_ndarray_imul() + test_ndarray_imul_broadcast() + test_ndarray_imul_broadcast_scalar() + test_ndarray_truediv() + test_ndarray_truediv_broadcast() + test_ndarray_truediv_broadcast_lhs_scalar() + test_ndarray_truediv_broadcast_rhs_scalar() + test_ndarray_itruediv() + test_ndarray_itruediv_broadcast() + test_ndarray_itruediv_broadcast_scalar() + test_ndarray_floordiv() + test_ndarray_floordiv_broadcast() + test_ndarray_floordiv_broadcast_lhs_scalar() + test_ndarray_floordiv_broadcast_rhs_scalar() + test_ndarray_ifloordiv() + test_ndarray_ifloordiv_broadcast() + test_ndarray_ifloordiv_broadcast_scalar() + test_ndarray_mod() + test_ndarray_mod_broadcast() + test_ndarray_mod_broadcast_lhs_scalar() + test_ndarray_mod_broadcast_rhs_scalar() + test_ndarray_imod() + test_ndarray_imod_broadcast() + test_ndarray_imod_broadcast_scalar() + test_ndarray_pow() + test_ndarray_pow_broadcast() + test_ndarray_pow_broadcast_lhs_scalar() + test_ndarray_pow_broadcast_rhs_scalar() + test_ndarray_ipow() + test_ndarray_ipow_broadcast() + test_ndarray_ipow_broadcast_scalar() return 0