diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index ccba682..5d7621c 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -5,7 +5,7 @@ use nac3core::{ toplevel::{ DefinitionId, helper::PRIMITIVE_DEF_IDS, - numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelDef, }, typecheck::{ @@ -654,7 +654,7 @@ impl InnerResolver { } } (TypeEnum::TObj { obj_id, .. }, false) if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (ty, ndims) = unpack_ndarray_tvars(unifier, extracted_ty); + let (ty, ndims) = unpack_ndarray_var_tys(unifier, extracted_ty); let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { assert!(matches!( diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 1bbb79d..e537d06 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -24,7 +24,7 @@ use crate::{ toplevel::{ DefinitionId, helper::PRIMITIVE_DEF_IDS, - numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelDef, }, typecheck::{ @@ -1131,7 +1131,7 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Ok(Some(res.into())) } else if ty1 == ty2 && matches!(&*ctx.unifier.get_ty(ty1), TypeEnum::TObj { obj_id, .. } if obj_id == &PRIMITIVE_DEF_IDS.ndarray) { let llvm_usize = generator.get_size_type(ctx.ctx); - let (ndarray_dtype, _) = unpack_ndarray_tvars(&mut ctx.unifier, ty1); + let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1); let left_val = NDArrayValue::from_ptr_val( left_val.into_pointer_value(), diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index c528eab..c07185e 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -2,7 +2,7 @@ use crate::{ symbol_resolver::{StaticValue, SymbolResolver}, toplevel::{ helper::PRIMITIVE_DEF_IDS, - numpy::unpack_ndarray_tvars, + numpy::unpack_ndarray_var_tys, TopLevelContext, TopLevelDef, }, @@ -451,7 +451,7 @@ fn get_llvm_type<'ctx, G: CodeGenerator + ?Sized>( TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { let llvm_usize = generator.get_size_type(ctx); - let (dtype, _) = unpack_ndarray_tvars(unifier, ty); + let (dtype, _) = unpack_ndarray_var_tys(unifier, ty); let element_type = get_llvm_type( ctx, module, diff --git a/nac3core/src/codegen/numpy.rs b/nac3core/src/codegen/numpy.rs index 624e535..e2c648f 100644 --- a/nac3core/src/codegen/numpy.rs +++ b/nac3core/src/codegen/numpy.rs @@ -27,7 +27,7 @@ use crate::{ symbol_resolver::ValueEnum, toplevel::{ DefinitionId, - numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, }, typecheck::typedef::{FunSignature, Type}, }; @@ -861,7 +861,7 @@ pub fn gen_ndarray_copy<'ctx>( let llvm_usize = generator.get_size_type(context.ctx); let this_ty = obj.as_ref().unwrap().0; - let (this_elem_ty, _) = unpack_ndarray_tvars(&mut context.unifier, this_ty); + let (this_elem_ty, _) = unpack_ndarray_var_tys(&mut context.unifier, this_ty); let this_arg = obj .as_ref() .unwrap() diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 2e6e8a6..fd5d51a 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -13,7 +13,7 @@ use crate::{ toplevel::{ DefinitionId, helper::PRIMITIVE_DEF_IDS, - numpy::unpack_ndarray_tvars, + numpy::unpack_ndarray_var_tys, TopLevelDef, }, typecheck::typedef::{FunSignature, Type, TypeEnum}, @@ -251,7 +251,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let ty = match &*ctx.unifier.get_ty_immutable(target.custom.unwrap()) { TypeEnum::TList { ty } => *ty, TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - unpack_ndarray_tvars(&mut ctx.unifier, target.custom.unwrap()).0 + unpack_ndarray_var_tys(&mut ctx.unifier, target.custom.unwrap()).0 } _ => unreachable!(), }; diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 20cefc7..bd8e358 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -347,6 +347,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .unwrap(); let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); + let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap(); + let ndarray_sub_ty = *ndarray_fields.get(&"__sub__".into()).unwrap(); + let ndarray_mul_ty = *ndarray_fields.get(&"__mul__".into()).unwrap(); + let ndarray_truediv_ty = *ndarray_fields.get(&"__truediv__".into()).unwrap(); + let ndarray_floordiv_ty = *ndarray_fields.get(&"__floordiv__".into()).unwrap(); + let ndarray_mod_ty = *ndarray_fields.get(&"__mod__".into()).unwrap(); + let ndarray_pow_ty = *ndarray_fields.get(&"__pow__".into()).unwrap(); + let ndarray_iadd_ty = *ndarray_fields.get(&"__iadd__".into()).unwrap(); + let ndarray_isub_ty = *ndarray_fields.get(&"__isub__".into()).unwrap(); + let ndarray_imul_ty = *ndarray_fields.get(&"__imul__".into()).unwrap(); + let ndarray_itruediv_ty = *ndarray_fields.get(&"__itruediv__".into()).unwrap(); + let ndarray_ifloordiv_ty = *ndarray_fields.get(&"__ifloordiv__".into()).unwrap(); + let ndarray_imod_ty = *ndarray_fields.get(&"__imod__".into()).unwrap(); + let ndarray_ipow_ty = *ndarray_fields.get(&"__ipow__".into()).unwrap(); let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( @@ -524,6 +538,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { methods: vec![ ("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)), ("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)), + ("__add__".into(), ndarray_add_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 3)), + ("__sub__".into(), ndarray_sub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 4)), + ("__mul__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 5)), + ("__truediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 6)), + ("__floordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 7)), + ("__mod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 8)), + ("__pow__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 9)), + ("__iadd__".into(), ndarray_iadd_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 10)), + ("__isub__".into(), ndarray_isub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 11)), + ("__imul__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 12)), + ("__itruediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 13)), + ("__ifloordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 14)), + ("__imod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 15)), + ("__ipow__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 16)), ], ancestors: Vec::default(), constructor: None, @@ -562,6 +590,216 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__add__".into(), + simple_name: "__add__".into(), + signature: ndarray_add_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__sub__".into(), + simple_name: "__sub__".into(), + signature: ndarray_sub_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__mul__".into(), + simple_name: "__mul__".into(), + signature: ndarray_mul_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__truediv__".into(), + simple_name: "__truediv__".into(), + signature: ndarray_truediv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__floordiv__".into(), + simple_name: "__floordiv__".into(), + signature: ndarray_floordiv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__mod__".into(), + simple_name: "__mod__".into(), + signature: ndarray_mod_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__pow__".into(), + simple_name: "__pow__".into(), + signature: ndarray_pow_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__iadd__".into(), + simple_name: "__iadd__".into(), + signature: ndarray_iadd_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__isub__".into(), + simple_name: "__isub__".into(), + signature: ndarray_isub_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__imul__".into(), + simple_name: "__imul__".into(), + signature: ndarray_imul_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__itruediv__".into(), + simple_name: "__itruediv__".into(), + signature: ndarray_itruediv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__ifloordiv__".into(), + simple_name: "__ifloordiv__".into(), + signature: ndarray_ifloordiv_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__imod__".into(), + simple_name: "__imod__".into(), + signature: ndarray_imod_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "ndarray.__ipow__".into(), + simple_name: "__ipow__".into(), + signature: ndarray_ipow_ty.0, + var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id], + instance_to_symbol: HashMap::default(), + instance_to_stmt: HashMap::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |_, _, _, _, _| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 8a3908e..ea42b92 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -1,6 +1,7 @@ use std::convert::TryInto; use crate::symbol_resolver::SymbolValue; +use crate::toplevel::numpy::subst_ndarray_tvars; use crate::typecheck::typedef::{Mapping, VarMap}; use nac3parser::ast::{Constant, Location}; @@ -226,11 +227,57 @@ impl TopLevelComposer { (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), ]), })); + let ndarray_binop_fun_other_ty = unifier.get_fresh_var(None, None); + let ndarray_binop_fun_ret_ty = unifier.get_fresh_var(None, None); + let ndarray_binop_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "other".into(), + ty: ndarray_binop_fun_other_ty.0, + default_value: None, + }, + ], + ret: ndarray_binop_fun_ret_ty.0, + vars: VarMap::from([ + (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), + (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), + ]), + })); + let ndarray_truediv_fun_other_ty = unifier.get_fresh_var(None, None); + let ndarray_truediv_fun_ret_ty = unifier.get_fresh_var(None, None); + let ndarray_truediv_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "other".into(), + ty: ndarray_truediv_fun_other_ty.0, + default_value: None, + }, + ], + ret: ndarray_truediv_fun_ret_ty.0, + vars: VarMap::from([ + (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), + (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), + ]), + })); let ndarray = unifier.add_ty(TypeEnum::TObj { obj_id: PRIMITIVE_DEF_IDS.ndarray, fields: Mapping::from([ ("copy".into(), (ndarray_copy_fun_ty, true)), ("fill".into(), (ndarray_fill_fun_ty, true)), + ("__add__".into(), (ndarray_binop_fun_ty, true)), + ("__sub__".into(), (ndarray_binop_fun_ty, true)), + ("__mul__".into(), (ndarray_binop_fun_ty, true)), + ("__truediv__".into(), (ndarray_truediv_fun_ty, true)), + ("__floordiv__".into(), (ndarray_binop_fun_ty, true)), + ("__mod__".into(), (ndarray_binop_fun_ty, true)), + ("__pow__".into(), (ndarray_binop_fun_ty, true)), + ("__iadd__".into(), (ndarray_binop_fun_ty, true)), + ("__isub__".into(), (ndarray_binop_fun_ty, true)), + ("__imul__".into(), (ndarray_binop_fun_ty, true)), + ("__itruediv__".into(), (ndarray_truediv_fun_ty, true)), + ("__ifloordiv__".into(), (ndarray_binop_fun_ty, true)), + ("__imod__".into(), (ndarray_binop_fun_ty, true)), + ("__ipow__".into(), (ndarray_binop_fun_ty, true)), ]), params: VarMap::from([ (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), @@ -239,6 +286,12 @@ impl TopLevelComposer { }); unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap(); + unifier.unify(ndarray_binop_fun_other_ty.0, ndarray).unwrap(); + unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap(); + + let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None); + unifier.unify(ndarray_truediv_fun_other_ty.0, ndarray).unwrap(); + unifier.unify(ndarray_truediv_fun_ret_ty.0, ndarray_float).unwrap(); let primitives = PrimitiveStore { int32, diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index d322519..aee0904 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -19,13 +19,30 @@ pub fn make_ndarray_ty( dtype: Option, ndims: Option, ) -> Type { - let ndarray = primitives.ndarray; + subst_ndarray_tvars(unifier, primitives.ndarray, dtype, ndims) +} +/// Substitutes type variables in `ndarray`. +/// +/// * `dtype` - The element type of the `ndarray`, or [`None`] if the type variable is not +/// specialized. +/// * `ndims` - The number of dimensions of the `ndarray`, or [`None`] if the type variable is not +/// specialized. +pub fn subst_ndarray_tvars( + unifier: &mut Unifier, + ndarray: Type, + dtype: Option, + ndims: Option, +) -> Type { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; debug_assert_eq!(*obj_id, PRIMITIVE_DEF_IDS.ndarray); + if dtype.is_none() && ndims.is_none() { + return ndarray + } + let tvar_ids = params.iter() .map(|(obj_id, _)| *obj_id) .collect_vec(); @@ -42,12 +59,10 @@ pub fn make_ndarray_ty( unifier.subst(ndarray, &tvar_subst).unwrap_or(ndarray) } -/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to -/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. -pub fn unpack_ndarray_tvars( +fn unpack_ndarray_tvars( unifier: &mut Unifier, ndarray: Type, -) -> (Type, Type) { +) -> Vec<(u32, Type)> { let TypeEnum::TObj { obj_id, params, .. } = &*unifier.get_ty_immutable(ndarray) else { panic!("Expected `ndarray` to be TObj, but got {}", unifier.stringify(ndarray)) }; @@ -56,7 +71,33 @@ pub fn unpack_ndarray_tvars( params.iter() .sorted_by_key(|(obj_id, _)| *obj_id) - .map(|(_, ty)| *ty) + .map(|(var_id, ty)| (*var_id, *ty)) + .collect_vec() +} + +/// Unpacks the type variable IDs of `ndarray` into a tuple. The elements of the tuple corresponds +/// to `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` +/// respectively. +pub fn unpack_ndarray_var_ids( + unifier: &mut Unifier, + ndarray: Type, +) -> (u32, u32) { + unpack_ndarray_tvars(unifier, ndarray) + .into_iter() + .map(|v| v.0) + .collect_tuple() + .unwrap() +} + +/// Unpacks the type variables of `ndarray` into a tuple. The elements of the tuple corresponds to +/// `dtype` (the element type) and `ndims` (the number of dimensions) of the `ndarray` respectively. +pub fn unpack_ndarray_var_tys( + unifier: &mut Unifier, + ndarray: Type, +) -> (Type, Type) { + unpack_ndarray_tvars(unifier, ndarray) + .into_iter() + .map(|v| v.1) .collect_tuple() .unwrap() } diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index d25fff5..53be568 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -1,3 +1,4 @@ +use crate::toplevel::numpy::make_ndarray_ty; use crate::typecheck::{ type_inferencer::*, typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier, VarMap}, @@ -234,8 +235,14 @@ pub fn impl_bitwise_shift(unifier: &mut Unifier, store: &PrimitiveStore, ty: Typ } /// `Div` -pub fn impl_div(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type, other_ty: &[Type]) { - impl_binop(unifier, store, ty, other_ty, store.float, &[Operator::Div]); +pub fn impl_div( + unifier: &mut Unifier, + store: &PrimitiveStore, + ty: Type, + other_ty: &[Type], + ret_ty: Type, +) { + impl_binop(unifier, store, ty, other_ty, ret_ty, &[Operator::Div]); } /// `FloorDiv` @@ -299,6 +306,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie bool: bool_t, uint32: uint32_t, uint64: uint64_t, + ndarray: ndarray_t, .. } = *store; @@ -308,7 +316,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_pow(unifier, store, t, &[t], t); impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_shift(unifier, store, t); - impl_div(unifier, store, t, &[t]); + impl_div(unifier, store, t, &[t], float_t); impl_floordiv(unifier, store, t, &[t], t); impl_mod(unifier, store, t, &[t], t); impl_invert(unifier, store, t); @@ -323,7 +331,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* float ======== */ impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); impl_pow(unifier, store, float_t, &[int32_t, float_t], float_t); - impl_div(unifier, store, float_t, &[float_t]); + impl_div(unifier, store, float_t, &[float_t], float_t); impl_floordiv(unifier, store, float_t, &[float_t], float_t); impl_mod(unifier, store, float_t, &[float_t], float_t); impl_sign(unifier, store, float_t); @@ -334,4 +342,12 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie /* bool ======== */ impl_not(unifier, store, bool_t); impl_eq(unifier, store, bool_t); + + /* ndarray ===== */ + let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None); + impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); + impl_pow(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); + impl_div(unifier, store, ndarray_t, &[ndarray_t], ndarray_float_t); + impl_floordiv(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); + impl_mod(unifier, store, ndarray_t, &[ndarray_t], ndarray_t); } diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index a6b72bb..22812ba 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -9,7 +9,7 @@ use crate::{ symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::{ helper::PRIMITIVE_DEF_IDS, - numpy::{make_ndarray_ty, unpack_ndarray_tvars}, + numpy::{make_ndarray_ty, unpack_ndarray_var_tys}, TopLevelContext, }, }; @@ -1334,7 +1334,7 @@ impl<'a> Inferencer<'a> { let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); + let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims)) } @@ -1347,7 +1347,7 @@ impl<'a> Inferencer<'a> { ExprKind::Constant { value: ast::Constant::Int(val), .. } => { match &*self.unifier.get_ty(value.custom.unwrap()) { TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); + let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); self.infer_subscript_ndarray(value, ty, ndims) } _ => { @@ -1379,7 +1379,7 @@ impl<'a> Inferencer<'a> { Ok(ty) } TypeEnum::TObj { obj_id, .. } if *obj_id == PRIMITIVE_DEF_IDS.ndarray => { - let (_, ndims) = unpack_ndarray_tvars(self.unifier, value.custom.unwrap()); + let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap()); self.constrain(slice.custom.unwrap(), self.primitives.usize(), &slice.location)?; self.infer_subscript_ndarray(value, ty, ndims) diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 353ebe5..5d91541 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -67,6 +67,167 @@ def test_ndarray_copy(): output_float64(y[1][0]) output_float64(y[1][1]) +def test_ndarray_add(): + x = np_identity(2) + y = x + np_ones([2, 2]) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + + output_float64(y[0][0]) + output_float64(y[0][1]) + output_float64(y[1][0]) + output_float64(y[1][1]) + +def test_ndarray_iadd(): + x = np_identity(2) + x += np_ones([2, 2]) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + +def test_ndarray_sub(): + x = np_ones([2, 2]) + y = x - np_identity(2) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + + output_float64(y[0][0]) + output_float64(y[0][1]) + output_float64(y[1][0]) + output_float64(y[1][1]) + +def test_ndarray_isub(): + x = np_ones([2, 2]) + x -= np_identity(2) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + +def test_ndarray_mul(): + x = np_ones([2, 2]) + y = x * np_identity(2) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + + output_float64(y[0][0]) + output_float64(y[0][1]) + output_float64(y[1][0]) + output_float64(y[1][1]) + +def test_ndarray_imul(): + x = np_ones([2, 2]) + x *= np_identity(2) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + +def test_ndarray_truediv(): + x = np_identity(2) + y = x / np_ones([2, 2]) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + + output_float64(y[0][0]) + output_float64(y[0][1]) + output_float64(y[1][0]) + output_float64(y[1][1]) + +def test_ndarray_itruediv(): + x = np_identity(2) + x /= np_ones([2, 2]) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + +def test_ndarray_floordiv(): + x = np_identity(2) + y = x // np_ones([2, 2]) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + + output_float64(y[0][0]) + output_float64(y[0][1]) + output_float64(y[1][0]) + output_float64(y[1][1]) + +def test_ndarray_ifloordiv(): + x = np_identity(2) + x //= np_ones([2, 2]) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + +def test_ndarray_mod(): + x = np_identity(2) + y = x % np_full([2, 2], 2.0) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + + output_float64(y[0][0]) + output_float64(y[0][1]) + output_float64(y[1][0]) + output_float64(y[1][1]) + +def test_ndarray_imod(): + x = np_identity(2) + x %= np_full([2, 2], 2.0) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + +def test_ndarray_pow(): + x = np_identity(2) + y = x ** np_full([2, 2], 2.0) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + + output_float64(y[0][0]) + output_float64(y[0][1]) + output_float64(y[1][0]) + output_float64(y[1][1]) + +def test_ndarray_ipow(): + x = np_identity(2) + x **= np_full([2, 2], 2.0) + + output_float64(x[0][0]) + output_float64(x[0][1]) + output_float64(x[1][0]) + output_float64(x[1][1]) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() @@ -77,5 +238,17 @@ def run() -> int32: test_ndarray_identity() test_ndarray_fill() test_ndarray_copy() + test_ndarray_add() + test_ndarray_iadd() + test_ndarray_sub() + test_ndarray_isub() + test_ndarray_mul() + test_ndarray_imul() + test_ndarray_truediv() + test_ndarray_itruediv() + test_ndarray_floordiv() + test_ndarray_ifloordiv() + test_ndarray_mod() + test_ndarray_imod() return 0