1
0
forked from M-Labs/nac3

core: Implement elementwise binary operators

Including immediate variants of these operators.
This commit is contained in:
David Mak 2024-03-13 11:16:23 +08:00
parent 3540d0ab29
commit 6af13a8261
16 changed files with 1049 additions and 56 deletions

View File

@ -11,7 +11,7 @@ use crate::codegen::{
stmt::gen_for_callback_incrementing, stmt::gen_for_callback_incrementing,
}; };
/// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of /// An LLVM value that is array-like, i.e. it contains a contiguous, sequenced collection of
/// elements. /// elements.
pub trait ArrayLikeValue<'ctx> { pub trait ArrayLikeValue<'ctx> {
/// Returns the element type of this array-like value. /// Returns the element type of this array-like value.

View File

@ -17,6 +17,7 @@ use crate::{
get_llvm_abi_type, get_llvm_abi_type,
irrt::*, irrt::*,
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi}, llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
numpy,
stmt::{gen_raise, gen_var}, stmt::{gen_raise, gen_var},
CodeGenContext, CodeGenTask, CodeGenContext, CodeGenTask,
}, },
@ -24,7 +25,7 @@ use crate::{
toplevel::{ toplevel::{
DefinitionId, DefinitionId,
helper::PRIMITIVE_DEF_IDS, helper::PRIMITIVE_DEF_IDS,
numpy::make_ndarray_ty, numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
TopLevelDef, TopLevelDef,
}, },
typecheck::{ typecheck::{
@ -1129,6 +1130,78 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
Some("f_pow_i") Some("f_pow_i")
); );
Ok(Some(res.into())) Ok(Some(res.into()))
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
let llvm_usize = generator.get_size_type(ctx.ctx);
let is_ndarray1 = ty1.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
let is_ndarray2 = ty2.obj_id(&ctx.unifier)
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
if is_ndarray1 && is_ndarray2 {
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
let left_val = NDArrayValue::from_ptr_val(
left_val.into_pointer_value(),
llvm_usize,
None
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ndarray_dtype1,
if is_aug_assign { Some(left_val) } else { None },
(left_val.as_ptr_value().into(), false),
(right_val, false),
|generator, ctx, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(ndarray_dtype1), lhs),
op,
(&Some(ndarray_dtype2), rhs),
ctx.current_loc,
is_aug_assign,
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1)
},
)?;
Ok(Some(res.as_ptr_value().into()))
} else {
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
&mut ctx.unifier,
if is_ndarray1 { ty1 } else { ty2 },
);
let ndarray_val = NDArrayValue::from_ptr_val(
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
llvm_usize,
None,
);
let res = numpy::ndarray_elementwise_binop_impl(
generator,
ctx,
ndarray_dtype,
if is_aug_assign { Some(ndarray_val) } else { None },
(left_val, !is_ndarray1),
(right_val, !is_ndarray2),
|generator, ctx, (lhs, rhs)| {
gen_binop_expr_with_values(
generator,
ctx,
(&Some(ndarray_dtype), lhs),
op,
(&Some(ndarray_dtype), rhs),
ctx.current_loc,
is_aug_assign,
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype)
},
)?;
Ok(Some(res.as_ptr_value().into()))
}
} else { } else {
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap()); let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else { let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {

View File

@ -18,6 +18,8 @@ use crate::{
CodeGenContext, CodeGenContext,
CodeGenerator, CodeGenerator,
irrt::{ irrt::{
call_ndarray_calc_broadcast,
call_ndarray_calc_broadcast_index,
call_ndarray_calc_nd_indices, call_ndarray_calc_nd_indices,
call_ndarray_calc_size, call_ndarray_calc_size,
}, },
@ -338,6 +340,98 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
) )
} }
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
/// the target `ndarray`.
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
target: NDArrayValue<'ctx>,
source: NDArrayValue<'ctx>,
) {
let array_ndims = source.load_ndims(ctx);
let broadcast_size = target.load_ndims(ctx);
ctx.make_assert(
generator,
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
"0:ValueError",
"operands cannot be broadcast together",
[None, None, None],
ctx.current_loc,
);
}
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
/// with broadcast-compatible shapes.
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
res: NDArrayValue<'ctx>,
lhs: (BasicValueEnum<'ctx>, bool),
rhs: (BasicValueEnum<'ctx>, bool),
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator + ?Sized,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let (lhs_val, lhs_scalar) = lhs;
let (rhs_val, rhs_scalar) = rhs;
assert!(!(lhs_scalar && rhs_scalar),
"One of the operands must be a ndarray instance: `{}`, `{}`",
lhs_val.get_type(),
rhs_val.get_type());
// Assert that all ndarray operands are broadcastable to the target size
if !lhs_scalar {
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
}
if !rhs_scalar {
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
}
ndarray_fill_indexed(
generator,
ctx,
res,
|generator, ctx, idx| {
let lhs_elem = if lhs_scalar {
lhs_val
} else {
let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, &idx);
unsafe {
lhs.data().get_unchecked(ctx, generator, lhs_idx, None)
}
};
let rhs_elem = if rhs_scalar {
rhs_val
} else {
let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, &idx);
unsafe {
rhs.data().get_unchecked(ctx, generator, rhs_idx, None)
}
};
debug_assert_eq!(lhs_elem.get_type(), rhs_elem.get_type());
value_fn(generator, ctx, (lhs_elem, rhs_elem))
},
)?;
Ok(res)
}
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. /// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
/// ///
/// * `elem_ty` - The element type of the `NDArray`. /// * `elem_ty` - The element type of the `NDArray`.
@ -562,6 +656,107 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
Ok(ndarray) Ok(ndarray)
} }
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
///
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple.
/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the
/// `value_fn` arguments tuple for all output elements.
///
/// The second element of the tuple indicates whether to treat the operand value as a `ndarray`
/// (which would be accessed by its broadcast index) or as a scalar value (which would be
/// broadcast to all elements).
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
/// written to a new `ndarray`.
/// * `value_fn` - Function mapping the two input elements into the result.
///
/// # Panic
///
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
res: Option<NDArrayValue<'ctx>>,
lhs: (BasicValueEnum<'ctx>, bool),
rhs: (BasicValueEnum<'ctx>, bool),
value_fn: ValueFn,
) -> Result<NDArrayValue<'ctx>, String>
where
G: CodeGenerator,
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
{
let llvm_usize = generator.get_size_type(ctx.ctx);
let (lhs_val, lhs_scalar) = lhs;
let (rhs_val, rhs_scalar) = rhs;
assert!(!(lhs_scalar && rhs_scalar),
"One of the operands must be a ndarray instance: `{}`, `{}`",
lhs_val.get_type(),
rhs_val.get_type());
let ndarray = res.unwrap_or_else(|| {
if lhs_scalar && rhs_scalar {
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray_dims,
|generator, ctx, v| {
Ok(v.size(ctx, generator))
},
|generator, ctx, v, idx| {
unsafe {
Ok(v.get_typed_unchecked(ctx, generator, idx, None))
}
},
).unwrap()
} else {
let ndarray = NDArrayValue::from_ptr_val(
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
llvm_usize,
None,
);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&ndarray,
|_, ctx, v| {
Ok(v.load_ndims(ctx))
},
|generator, ctx, v, idx| {
unsafe {
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None))
}
},
).unwrap()
}
});
ndarray_broadcast_fill(
generator,
ctx,
ndarray,
lhs,
rhs,
|generator, ctx, elems| {
value_fn(generator, ctx, elems)
},
)?;
Ok(ndarray)
}
/// Generates LLVM IR for `ndarray.empty`. /// Generates LLVM IR for `ndarray.empty`.
pub fn gen_ndarray_empty<'ctx>( pub fn gen_ndarray_empty<'ctx>(
context: &mut CodeGenContext<'ctx, '_>, context: &mut CodeGenContext<'ctx, '_>,

View File

@ -546,7 +546,7 @@ pub fn gen_for_callback<'ctx, 'a, G, I, InitFn, CondFn, BodyFn, UpdateFn>(
/// body(x); /// body(x);
/// } /// }
/// ``` /// ```
/// ///
/// * `init_val` - The initial value of the loop variable. The type of this value will also be used /// * `init_val` - The initial value of the loop variable. The type of this value will also be used
/// as the type of the loop variable. /// as the type of the loop variable.
/// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum /// * `max_val` - A tuple containing the maximum value of the loop variable, and whether the maximum

View File

@ -299,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
Some("N".into()), Some("N".into()),
None, None,
); );
let size_t = primitives.0.usize();
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect(); let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
let exception_fields = vec![ let exception_fields = vec![
("__name__".into(), int32, true), ("__name__".into(), int32, true),
@ -345,8 +347,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
.nth(1) .nth(1)
.map(|(var_id, ty)| (*ty, *var_id)) .map(|(var_id, ty)| (*ty, *var_id))
.unwrap(); .unwrap();
let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var(
size_t,
Some("ndarray_ndims".into()),
None,
);
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap(); let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap(); let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap();
let ndarray_sub_ty = *ndarray_fields.get(&"__sub__".into()).unwrap();
let ndarray_mul_ty = *ndarray_fields.get(&"__mul__".into()).unwrap();
let ndarray_truediv_ty = *ndarray_fields.get(&"__truediv__".into()).unwrap();
let ndarray_floordiv_ty = *ndarray_fields.get(&"__floordiv__".into()).unwrap();
let ndarray_mod_ty = *ndarray_fields.get(&"__mod__".into()).unwrap();
let ndarray_pow_ty = *ndarray_fields.get(&"__pow__".into()).unwrap();
let ndarray_iadd_ty = *ndarray_fields.get(&"__iadd__".into()).unwrap();
let ndarray_isub_ty = *ndarray_fields.get(&"__isub__".into()).unwrap();
let ndarray_imul_ty = *ndarray_fields.get(&"__imul__".into()).unwrap();
let ndarray_itruediv_ty = *ndarray_fields.get(&"__itruediv__".into()).unwrap();
let ndarray_ifloordiv_ty = *ndarray_fields.get(&"__ifloordiv__".into()).unwrap();
let ndarray_imod_ty = *ndarray_fields.get(&"__imod__".into()).unwrap();
let ndarray_ipow_ty = *ndarray_fields.get(&"__ipow__".into()).unwrap();
let top_level_def_list = vec![ let top_level_def_list = vec![
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
@ -524,6 +545,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
methods: vec![ methods: vec![
("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)), ("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)), ("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)),
("__add__".into(), ndarray_add_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 3)),
("__sub__".into(), ndarray_sub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 4)),
("__mul__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 5)),
("__truediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 6)),
("__floordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 7)),
("__mod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 8)),
("__pow__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 9)),
("__iadd__".into(), ndarray_iadd_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 10)),
("__isub__".into(), ndarray_isub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 11)),
("__imul__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 12)),
("__itruediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 13)),
("__ifloordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 14)),
("__imod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 15)),
("__ipow__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 16)),
], ],
ancestors: Vec::default(), ancestors: Vec::default(),
constructor: None, constructor: None,
@ -562,6 +597,216 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
)))), )))),
loc: None, loc: None,
})), })),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__add__".into(),
simple_name: "__add__".into(),
signature: ndarray_add_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__sub__".into(),
simple_name: "__sub__".into(),
signature: ndarray_sub_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__mul__".into(),
simple_name: "__mul__".into(),
signature: ndarray_mul_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__truediv__".into(),
simple_name: "__truediv__".into(),
signature: ndarray_truediv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__floordiv__".into(),
simple_name: "__floordiv__".into(),
signature: ndarray_floordiv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__mod__".into(),
simple_name: "__mod__".into(),
signature: ndarray_mod_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__pow__".into(),
simple_name: "__pow__".into(),
signature: ndarray_pow_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__iadd__".into(),
simple_name: "__iadd__".into(),
signature: ndarray_iadd_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__isub__".into(),
simple_name: "__isub__".into(),
signature: ndarray_isub_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__imul__".into(),
simple_name: "__imul__".into(),
signature: ndarray_imul_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__itruediv__".into(),
simple_name: "__itruediv__".into(),
signature: ndarray_itruediv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__ifloordiv__".into(),
simple_name: "__ifloordiv__".into(),
signature: ndarray_ifloordiv_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__imod__".into(),
simple_name: "__imod__".into(),
signature: ndarray_imod_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function {
name: "ndarray.__ipow__".into(),
simple_name: "__ipow__".into(),
signature: ndarray_ipow_ty.0,
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
instance_to_symbol: HashMap::default(),
instance_to_stmt: HashMap::default(),
resolver: None,
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|_, _, _, _, _| {
unreachable!("handled in gen_expr")
},
)))),
loc: None,
})),
Arc::new(RwLock::new(TopLevelDef::Function { Arc::new(RwLock::new(TopLevelDef::Function {
name: "int32".into(), name: "int32".into(),
simple_name: "int32".into(), simple_name: "int32".into(),

View File

@ -1,6 +1,7 @@
use std::convert::TryInto; use std::convert::TryInto;
use crate::symbol_resolver::SymbolValue; use crate::symbol_resolver::SymbolValue;
use crate::toplevel::numpy::subst_ndarray_tvars;
use crate::typecheck::typedef::{Mapping, VarMap}; use crate::typecheck::typedef::{Mapping, VarMap};
use nac3parser::ast::{Constant, Location}; use nac3parser::ast::{Constant, Location};
@ -231,11 +232,57 @@ impl TopLevelComposer {
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0), (ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]), ]),
})); }));
let ndarray_binop_fun_other_ty = unifier.get_fresh_var(None, None);
let ndarray_binop_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_binop_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "other".into(),
ty: ndarray_binop_fun_other_ty.0,
default_value: None,
},
],
ret: ndarray_binop_fun_ret_ty.0,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray_truediv_fun_other_ty = unifier.get_fresh_var(None, None);
let ndarray_truediv_fun_ret_ty = unifier.get_fresh_var(None, None);
let ndarray_truediv_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![
FuncArg {
name: "other".into(),
ty: ndarray_truediv_fun_other_ty.0,
default_value: None,
},
],
ret: ndarray_truediv_fun_ret_ty.0,
vars: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}));
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray, obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: Mapping::from([ fields: Mapping::from([
("copy".into(), (ndarray_copy_fun_ty, true)), ("copy".into(), (ndarray_copy_fun_ty, true)),
("fill".into(), (ndarray_fill_fun_ty, true)), ("fill".into(), (ndarray_fill_fun_ty, true)),
("__add__".into(), (ndarray_binop_fun_ty, true)),
("__sub__".into(), (ndarray_binop_fun_ty, true)),
("__mul__".into(), (ndarray_binop_fun_ty, true)),
("__truediv__".into(), (ndarray_truediv_fun_ty, true)),
("__floordiv__".into(), (ndarray_binop_fun_ty, true)),
("__mod__".into(), (ndarray_binop_fun_ty, true)),
("__pow__".into(), (ndarray_binop_fun_ty, true)),
("__iadd__".into(), (ndarray_binop_fun_ty, true)),
("__isub__".into(), (ndarray_binop_fun_ty, true)),
("__imul__".into(), (ndarray_binop_fun_ty, true)),
("__itruediv__".into(), (ndarray_truediv_fun_ty, true)),
("__ifloordiv__".into(), (ndarray_binop_fun_ty, true)),
("__imod__".into(), (ndarray_binop_fun_ty, true)),
("__ipow__".into(), (ndarray_binop_fun_ty, true)),
]), ]),
params: VarMap::from([ params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0), (ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
@ -243,7 +290,16 @@ impl TopLevelComposer {
]), ]),
}); });
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None);
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap(); unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap();
unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap();
let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None);
unifier.unify(ndarray_truediv_fun_other_ty.0, ndarray).unwrap();
unifier.unify(ndarray_truediv_fun_ret_ty.0, ndarray_float).unwrap();
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,

View File

@ -5,7 +5,7 @@ expression: res_vec
[ [
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [30]\n}\n", "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [124]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar19]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar19\"]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B[typevar113]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar113\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[ [
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [126]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [131]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec expression: res_vec
--- ---
[ [
"Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[typevar112, typevar113]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar112\", \"typevar113\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [38]\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [132]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [46]\n}\n", "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [140]\n}\n",
] ]

View File

@ -453,8 +453,8 @@ pub fn typeof_binop(
} }
Operator::LShift Operator::LShift
| Operator::RShift | Operator::RShift => lhs,
| Operator::BitOr Operator::BitOr
| Operator::BitXor | Operator::BitXor
| Operator::BitAnd => { | Operator::BitAnd => {
if unifier.unioned(lhs, rhs) { if unifier.unioned(lhs, rhs) {
@ -474,18 +474,21 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
bool: bool_t, bool: bool_t,
uint32: uint32_t, uint32: uint32_t,
uint64: uint64_t, uint64: uint64_t,
ndarray: ndarray_t,
.. ..
} = *store; } = *store;
let size_t = store.usize();
/* int ======== */ /* int ======== */
for t in [int32_t, int64_t, uint32_t, uint64_t] { for t in [int32_t, int64_t, uint32_t, uint64_t] {
impl_basic_arithmetic(unifier, store, t, &[t], Some(t)); let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
impl_pow(unifier, store, t, &[t], Some(t)); impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
impl_bitwise_arithmetic(unifier, store, t); impl_bitwise_arithmetic(unifier, store, t);
impl_bitwise_shift(unifier, store, t); impl_bitwise_shift(unifier, store, t);
impl_div(unifier, store, t, &[t], Some(float_t)); impl_div(unifier, store, t, &[t, ndarray_int_t], None);
impl_floordiv(unifier, store, t, &[t], Some(t)); impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
impl_mod(unifier, store, t, &[t], Some(t)); impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
impl_invert(unifier, store, t, Some(t)); impl_invert(unifier, store, t, Some(t));
impl_not(unifier, store, t, Some(bool_t)); impl_not(unifier, store, t, Some(bool_t));
impl_comparison(unifier, store, t, &[t], Some(bool_t)); impl_comparison(unifier, store, t, &[t], Some(bool_t));
@ -496,11 +499,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
} }
/* float ======== */ /* float ======== */
impl_basic_arithmetic(unifier, store, float_t, &[float_t], Some(float_t)); let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
impl_pow(unifier, store, float_t, &[int32_t, float_t], Some(float_t)); let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
impl_div(unifier, store, float_t, &[float_t], Some(float_t)); impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_floordiv(unifier, store, float_t, &[float_t], Some(float_t)); impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
impl_mod(unifier, store, float_t, &[float_t], Some(float_t)); impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
impl_sign(unifier, store, float_t, Some(float_t)); impl_sign(unifier, store, float_t, Some(float_t));
impl_not(unifier, store, float_t, Some(bool_t)); impl_not(unifier, store, float_t, Some(bool_t));
impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t)); impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t));
@ -509,4 +514,15 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
/* bool ======== */ /* bool ======== */
impl_not(unifier, store, bool_t, Some(bool_t)); impl_not(unifier, store, bool_t, Some(bool_t));
impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t)); impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t));
/* ndarray ===== */
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
} }

View File

@ -1203,8 +1203,11 @@ impl<'a> Inferencer<'a> {
right: &ast::Expr<Option<Type>>, right: &ast::Expr<Option<Type>>,
is_aug_assign: bool, is_aug_assign: bool,
) -> InferenceResult { ) -> InferenceResult {
let left_ty = left.custom.unwrap();
let right_ty = right.custom.unwrap();
let method = if let TypeEnum::TObj { fields, .. } = let method = if let TypeEnum::TObj { fields, .. } =
self.unifier.get_ty_immutable(left.custom.unwrap()).as_ref() self.unifier.get_ty_immutable(left_ty).as_ref()
{ {
let (binop_name, binop_assign_name) = ( let (binop_name, binop_assign_name) = (
binop_name(op).into(), binop_name(op).into(),
@ -1219,12 +1222,26 @@ impl<'a> Inferencer<'a> {
} else { } else {
binop_name(op).into() binop_name(op).into()
}; };
let ret = if is_aug_assign {
// The type of augmented assignment operator should never change
Some(left_ty)
} else {
typeof_binop(
self.unifier,
self.primitives,
op,
left_ty,
right_ty,
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
};
self.build_method_call( self.build_method_call(
location, location,
method, method,
left.custom.unwrap(), left_ty,
vec![right.custom.unwrap()], vec![right_ty],
None, ret,
) )
} }

View File

@ -135,10 +135,15 @@ impl TestEnvironment {
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::new(),
}); });
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
let ndarray = unifier.add_ty(TypeEnum::TObj { let ndarray = unifier.add_ty(TypeEnum::TObj {
obj_id: PRIMITIVE_DEF_IDS.ndarray, obj_id: PRIMITIVE_DEF_IDS.ndarray,
fields: HashMap::new(), fields: HashMap::new(),
params: VarMap::new(), params: VarMap::from([
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
]),
}); });
let primitives = PrimitiveStore { let primitives = PrimitiveStore {
int32, int32,

View File

@ -774,12 +774,8 @@ impl Unifier {
// If the types don't match, try to implicitly promote integers // If the types don't match, try to implicitly promote integers
if !self.unioned(ty, value_ty) { if !self.unioned(ty, value_ty) {
let num_val = match *value { let Ok(num_val) = i128::try_from(value.clone()) else {
SymbolValue::I32(v) => v as i128, return Self::incompatible_types(a, b)
SymbolValue::I64(v) => v as i128,
SymbolValue::U32(v) => v as i128,
SymbolValue::U64(v) => v as i128,
_ => return Self::incompatible_types(a, b),
}; };
let can_convert = if self.unioned(ty, primitives.int32) { let can_convert = if self.unioned(ty, primitives.int32) {

View File

@ -6,6 +6,19 @@ def output_int32(x: int32):
def output_float64(x: float): def output_float64(x: float):
... ...
def output_ndarray_int32_1(n: ndarray[int32, Literal[1]]):
for i in range(len(n)):
output_int32(n[i])
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
for i in range(len(n)):
output_float64(n[i])
def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
for r in range(len(n)):
for c in range(len(n[r])):
output_float64(n[r][c])
def consume_ndarray_1(n: ndarray[float, Literal[1]]): def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass pass
@ -19,53 +32,381 @@ def test_ndarray_empty():
def test_ndarray_zeros(): def test_ndarray_zeros():
n: ndarray[float, 1] = np_zeros([1]) n: ndarray[float, 1] = np_zeros([1])
output_float64(n[0]) output_ndarray_float_1(n)
def test_ndarray_ones(): def test_ndarray_ones():
n: ndarray[float, 1] = np_ones([1]) n: ndarray[float, 1] = np_ones([1])
output_float64(n[0]) output_ndarray_float_1(n)
def test_ndarray_full(): def test_ndarray_full():
n_float: ndarray[float, 1] = np_full([1], 2.0) n_float: ndarray[float, 1] = np_full([1], 2.0)
output_float64(n_float[0]) output_ndarray_float_1(n_float)
n_i32: ndarray[int32, 1] = np_full([1], 2) n_i32: ndarray[int32, 1] = np_full([1], 2)
output_int32(n_i32[0]) output_ndarray_int32_1(n_i32)
def test_ndarray_eye(): def test_ndarray_eye():
n: ndarray[float, 2] = np_eye(2) n: ndarray[float, 2] = np_eye(2)
n0: ndarray[float, 1] = n[0] output_ndarray_float_2(n)
v: float = n0[0]
output_float64(v)
def test_ndarray_identity(): def test_ndarray_identity():
n: ndarray[float, 2] = np_identity(2) n: ndarray[float, 2] = np_identity(2)
output_float64(n[0][0]) output_ndarray_float_2(n)
output_float64(n[0][1])
output_float64(n[1][0])
output_float64(n[1][1])
def test_ndarray_fill(): def test_ndarray_fill():
n: ndarray[float, 2] = np_empty([2, 2]) n: ndarray[float, 2] = np_empty([2, 2])
n.fill(1.0) n.fill(1.0)
output_float64(n[0][0]) output_ndarray_float_2(n)
output_float64(n[0][1])
output_float64(n[1][0])
output_float64(n[1][1])
def test_ndarray_copy(): def test_ndarray_copy():
x: ndarray[float, 2] = np_identity(2) x: ndarray[float, 2] = np_identity(2)
y = x.copy() y = x.copy()
x.fill(0.0) x.fill(0.0)
output_float64(x[0][0]) output_ndarray_float_2(x)
output_float64(x[0][1]) output_ndarray_float_2(y)
output_float64(x[1][0])
output_float64(x[1][1])
output_float64(y[0][0]) def test_ndarray_add():
output_float64(y[0][1]) x = np_identity(2)
output_float64(y[1][0]) y = x + np_ones([2, 2])
output_float64(y[1][1])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_add_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x + np_ones([2])
y = x + np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_add_broadcast_lhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = 1.0 + x
y = 1.0 + x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_add_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x + 1.0
y = x + 1.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_iadd():
x = np_identity(2)
x += np_ones([2, 2])
output_ndarray_float_2(x)
def test_ndarray_iadd_broadcast():
x = np_identity(2)
x += np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_iadd_broadcast_scalar():
x = np_identity(2)
x += 1.0
output_ndarray_float_2(x)
def test_ndarray_sub():
x = np_ones([2, 2])
y = x - np_identity(2)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_sub_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x - np_ones([2])
y = x - np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_sub_broadcast_lhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = 1.0 - x
y = 1.0 - x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_sub_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x - 1
y = x - 1.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_isub():
x = np_ones([2, 2])
x -= np_identity(2)
output_ndarray_float_2(x)
def test_ndarray_isub_broadcast():
x = np_identity(2)
x -= np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_isub_broadcast_scalar():
x = np_identity(2)
x -= 1.0
output_ndarray_float_2(x)
def test_ndarray_mul():
x = np_ones([2, 2])
y = x * np_identity(2)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mul_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x * np_ones([2])
y = x * np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mul_broadcast_lhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = 2.0 * x
y = 2.0 * x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mul_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x * 2.0
y = x * 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_imul():
x = np_ones([2, 2])
x *= np_identity(2)
output_ndarray_float_2(x)
def test_ndarray_imul_broadcast():
x = np_identity(2)
x *= np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_imul_broadcast_scalar():
x = np_identity(2)
x *= 2.0
output_ndarray_float_2(x)
def test_ndarray_truediv():
x = np_identity(2)
y = x / np_ones([2, 2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_truediv_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x / np_ones([2])
y = x / np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_truediv_broadcast_lhs_scalar():
x = np_ones([2, 2])
# y: ndarray[float, 2] = 2.0 / x
y = 2.0 / x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_truediv_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x / 2.0
y = x / 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_itruediv():
x = np_identity(2)
x /= np_ones([2, 2])
output_ndarray_float_2(x)
def test_ndarray_itruediv_broadcast():
x = np_identity(2)
x /= np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_itruediv_broadcast_scalar():
x = np_identity(2)
x /= 2.0
output_ndarray_float_2(x)
def test_ndarray_floordiv():
x = np_identity(2)
y = x // np_ones([2, 2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_floordiv_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x // np_ones([2])
y = x // np_ones([2])
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_floordiv_broadcast_lhs_scalar():
x = np_ones([2, 2])
# y: ndarray[float, 2] = 2.0 // x
y = 2.0 // x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_floordiv_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x // 2.0
y = x // 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_ifloordiv():
x = np_identity(2)
x //= np_ones([2, 2])
output_ndarray_float_2(x)
def test_ndarray_ifloordiv_broadcast():
x = np_identity(2)
x //= np_ones([2])
output_ndarray_float_2(x)
def test_ndarray_ifloordiv_broadcast_scalar():
x = np_identity(2)
x //= 2.0
output_ndarray_float_2(x)
def test_ndarray_mod():
x = np_identity(2)
y = x % np_full([2, 2], 2.0)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mod_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x % np_ones([2])
y = x % np_full([2], 2.0)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mod_broadcast_lhs_scalar():
x = np_ones([2, 2])
# y: ndarray[float, 2] = 2.0 % x
y = 2.0 % x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_mod_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x % 2.0
y = x % 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_imod():
x = np_identity(2)
x %= np_full([2, 2], 2.0)
output_ndarray_float_2(x)
def test_ndarray_imod_broadcast():
x = np_identity(2)
x %= np_full([2], 2.0)
output_ndarray_float_2(x)
def test_ndarray_imod_broadcast_scalar():
x = np_identity(2)
x %= 2.0
output_ndarray_float_2(x)
def test_ndarray_pow():
x = np_identity(2)
y = x ** np_full([2, 2], 2.0)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pow_broadcast():
x = np_identity(2)
# y: ndarray[float, 2] = x ** np_full([2], 2.0)
y = x ** np_full([2], 2.0)
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pow_broadcast_lhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = 2.0 ** x
y = 2.0 ** x
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_pow_broadcast_rhs_scalar():
x = np_identity(2)
# y: ndarray[float, 2] = x % 2.0
y = x ** 2.0
output_ndarray_float_2(x)
output_ndarray_float_2(y)
def test_ndarray_ipow():
x = np_identity(2)
x **= np_full([2, 2], 2.0)
output_ndarray_float_2(x)
def test_ndarray_ipow_broadcast():
x = np_identity(2)
x **= np_full([2], 2.0)
output_ndarray_float_2(x)
def test_ndarray_ipow_broadcast_scalar():
x = np_identity(2)
x **= 2.0
output_ndarray_float_2(x)
def run() -> int32: def run() -> int32:
test_ndarray_ctor() test_ndarray_ctor()
@ -77,5 +418,54 @@ def run() -> int32:
test_ndarray_identity() test_ndarray_identity()
test_ndarray_fill() test_ndarray_fill()
test_ndarray_copy() test_ndarray_copy()
test_ndarray_add()
test_ndarray_add_broadcast()
test_ndarray_add_broadcast_lhs_scalar()
test_ndarray_add_broadcast_rhs_scalar()
test_ndarray_iadd()
test_ndarray_iadd_broadcast()
test_ndarray_iadd_broadcast_scalar()
test_ndarray_sub()
test_ndarray_sub_broadcast()
test_ndarray_sub_broadcast_lhs_scalar()
test_ndarray_sub_broadcast_rhs_scalar()
test_ndarray_isub()
test_ndarray_isub_broadcast()
test_ndarray_isub_broadcast_scalar()
test_ndarray_mul()
test_ndarray_mul_broadcast()
test_ndarray_mul_broadcast_lhs_scalar()
test_ndarray_mul_broadcast_rhs_scalar()
test_ndarray_imul()
test_ndarray_imul_broadcast()
test_ndarray_imul_broadcast_scalar()
test_ndarray_truediv()
test_ndarray_truediv_broadcast()
test_ndarray_truediv_broadcast_lhs_scalar()
test_ndarray_truediv_broadcast_rhs_scalar()
test_ndarray_itruediv()
test_ndarray_itruediv_broadcast()
test_ndarray_itruediv_broadcast_scalar()
test_ndarray_floordiv()
test_ndarray_floordiv_broadcast()
test_ndarray_floordiv_broadcast_lhs_scalar()
test_ndarray_floordiv_broadcast_rhs_scalar()
test_ndarray_ifloordiv()
test_ndarray_ifloordiv_broadcast()
test_ndarray_ifloordiv_broadcast_scalar()
test_ndarray_mod()
test_ndarray_mod_broadcast()
test_ndarray_mod_broadcast_lhs_scalar()
test_ndarray_mod_broadcast_rhs_scalar()
test_ndarray_imod()
test_ndarray_imod_broadcast()
test_ndarray_imod_broadcast_scalar()
test_ndarray_pow()
test_ndarray_pow_broadcast()
test_ndarray_pow_broadcast_lhs_scalar()
test_ndarray_pow_broadcast_rhs_scalar()
test_ndarray_ipow()
test_ndarray_ipow_broadcast()
test_ndarray_ipow_broadcast_scalar()
return 0 return 0