forked from M-Labs/nac3
core: Implement elementwise binary operators
Including immediate variants of these operators.
This commit is contained in:
parent
3540d0ab29
commit
6af13a8261
|
@ -17,6 +17,7 @@ use crate::{
|
|||
get_llvm_abi_type,
|
||||
irrt::*,
|
||||
llvm_intrinsics::{call_expect, call_float_floor, call_float_pow, call_float_powi},
|
||||
numpy,
|
||||
stmt::{gen_raise, gen_var},
|
||||
CodeGenContext, CodeGenTask,
|
||||
},
|
||||
|
@ -24,7 +25,7 @@ use crate::{
|
|||
toplevel::{
|
||||
DefinitionId,
|
||||
helper::PRIMITIVE_DEF_IDS,
|
||||
numpy::make_ndarray_ty,
|
||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||
TopLevelDef,
|
||||
},
|
||||
typecheck::{
|
||||
|
@ -1129,6 +1130,78 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||
Some("f_pow_i")
|
||||
);
|
||||
Ok(Some(res.into()))
|
||||
} else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let is_ndarray1 = ty1.obj_id(&ctx.unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
let is_ndarray2 = ty2.obj_id(&ctx.unifier)
|
||||
.is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray);
|
||||
|
||||
if is_ndarray1 && is_ndarray2 {
|
||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty1);
|
||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty2);
|
||||
|
||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
||||
|
||||
let left_val = NDArrayValue::from_ptr_val(
|
||||
left_val.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype1,
|
||||
if is_aug_assign { Some(left_val) } else { None },
|
||||
(left_val.as_ptr_value().into(), false),
|
||||
(right_val, false),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(ndarray_dtype1), lhs),
|
||||
op,
|
||||
(&Some(ndarray_dtype2), rhs),
|
||||
ctx.current_loc,
|
||||
is_aug_assign,
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype1)
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_ptr_value().into()))
|
||||
} else {
|
||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
||||
&mut ctx.unifier,
|
||||
if is_ndarray1 { ty1 } else { ty2 },
|
||||
);
|
||||
let ndarray_val = NDArrayValue::from_ptr_val(
|
||||
if is_ndarray1 { left_val } else { right_val }.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
let res = numpy::ndarray_elementwise_binop_impl(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray_dtype,
|
||||
if is_aug_assign { Some(ndarray_val) } else { None },
|
||||
(left_val, !is_ndarray1),
|
||||
(right_val, !is_ndarray2),
|
||||
|generator, ctx, (lhs, rhs)| {
|
||||
gen_binop_expr_with_values(
|
||||
generator,
|
||||
ctx,
|
||||
(&Some(ndarray_dtype), lhs),
|
||||
op,
|
||||
(&Some(ndarray_dtype), rhs),
|
||||
ctx.current_loc,
|
||||
is_aug_assign,
|
||||
)?.unwrap().to_basic_value_enum(ctx, generator, ndarray_dtype)
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(Some(res.as_ptr_value().into()))
|
||||
}
|
||||
} else {
|
||||
let left_ty_enum = ctx.unifier.get_ty_immutable(left_ty.unwrap());
|
||||
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
|
||||
|
|
|
@ -18,6 +18,8 @@ use crate::{
|
|||
CodeGenContext,
|
||||
CodeGenerator,
|
||||
irrt::{
|
||||
call_ndarray_calc_broadcast,
|
||||
call_ndarray_calc_broadcast_index,
|
||||
call_ndarray_calc_nd_indices,
|
||||
call_ndarray_calc_size,
|
||||
},
|
||||
|
@ -338,6 +340,98 @@ fn ndarray_fill_indexed<'ctx, G, ValueFn>(
|
|||
)
|
||||
}
|
||||
|
||||
/// Generates the LLVM IR for checking whether the source `ndarray` can be broadcast to the shape of
|
||||
/// the target `ndarray`.
|
||||
fn ndarray_assert_is_broadcastable<'ctx, G: CodeGenerator + ?Sized>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
target: NDArrayValue<'ctx>,
|
||||
source: NDArrayValue<'ctx>,
|
||||
) {
|
||||
let array_ndims = source.load_ndims(ctx);
|
||||
let broadcast_size = target.load_ndims(ctx);
|
||||
|
||||
ctx.make_assert(
|
||||
generator,
|
||||
ctx.builder.build_int_compare(IntPredicate::ULE, array_ndims, broadcast_size, "").unwrap(),
|
||||
"0:ValueError",
|
||||
"operands cannot be broadcast together",
|
||||
[None, None, None],
|
||||
ctx.current_loc,
|
||||
);
|
||||
}
|
||||
|
||||
/// Generates the LLVM IR for populating the entire `NDArray` from two `ndarray` or scalar value
|
||||
/// with broadcast-compatible shapes.
|
||||
fn ndarray_broadcast_fill<'ctx, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
res: NDArrayValue<'ctx>,
|
||||
lhs: (BasicValueEnum<'ctx>, bool),
|
||||
rhs: (BasicValueEnum<'ctx>, bool),
|
||||
value_fn: ValueFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator + ?Sized,
|
||||
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (lhs_val, lhs_scalar) = lhs;
|
||||
let (rhs_val, rhs_scalar) = rhs;
|
||||
|
||||
assert!(!(lhs_scalar && rhs_scalar),
|
||||
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
||||
lhs_val.get_type(),
|
||||
rhs_val.get_type());
|
||||
|
||||
// Assert that all ndarray operands are broadcastable to the target size
|
||||
if !lhs_scalar {
|
||||
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
|
||||
ndarray_assert_is_broadcastable(generator, ctx, res, lhs_val);
|
||||
}
|
||||
|
||||
if !rhs_scalar {
|
||||
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
|
||||
ndarray_assert_is_broadcastable(generator, ctx, res, rhs_val);
|
||||
}
|
||||
|
||||
ndarray_fill_indexed(
|
||||
generator,
|
||||
ctx,
|
||||
res,
|
||||
|generator, ctx, idx| {
|
||||
let lhs_elem = if lhs_scalar {
|
||||
lhs_val
|
||||
} else {
|
||||
let lhs = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
|
||||
let lhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, lhs, &idx);
|
||||
|
||||
unsafe {
|
||||
lhs.data().get_unchecked(ctx, generator, lhs_idx, None)
|
||||
}
|
||||
};
|
||||
|
||||
let rhs_elem = if rhs_scalar {
|
||||
rhs_val
|
||||
} else {
|
||||
let rhs = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
|
||||
let rhs_idx = call_ndarray_calc_broadcast_index(generator, ctx, rhs, &idx);
|
||||
|
||||
unsafe {
|
||||
rhs.data().get_unchecked(ctx, generator, rhs_idx, None)
|
||||
}
|
||||
};
|
||||
|
||||
debug_assert_eq!(lhs_elem.get_type(), rhs_elem.get_type());
|
||||
|
||||
value_fn(generator, ctx, (lhs_elem, rhs_elem))
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`.
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
|
@ -562,6 +656,107 @@ fn ndarray_copy_impl<'ctx, G: CodeGenerator + ?Sized>(
|
|||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// LLVM-typed implementation for computing elementwise binary operations on two input operands.
|
||||
///
|
||||
/// If the operand is a `ndarray`, the broadcast index corresponding to each element in the output
|
||||
/// is computed, the element accessed and used as an operand of the `value_fn` arguments tuple.
|
||||
/// Otherwise, the operand is treated as a scalar value, and is used as an operand of the
|
||||
/// `value_fn` arguments tuple for all output elements.
|
||||
///
|
||||
/// The second element of the tuple indicates whether to treat the operand value as a `ndarray`
|
||||
/// (which would be accessed by its broadcast index) or as a scalar value (which would be
|
||||
/// broadcast to all elements).
|
||||
///
|
||||
/// * `elem_ty` - The element type of the `NDArray`.
|
||||
/// * `res` - The `ndarray` instance to write results into, or [`None`] if the result should be
|
||||
/// written to a new `ndarray`.
|
||||
/// * `value_fn` - Function mapping the two input elements into the result.
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// This function will panic if neither input operands (`lhs` or `rhs`) is a `ndarray`.
|
||||
pub fn ndarray_elementwise_binop_impl<'ctx, G, ValueFn>(
|
||||
generator: &mut G,
|
||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||
elem_ty: Type,
|
||||
res: Option<NDArrayValue<'ctx>>,
|
||||
lhs: (BasicValueEnum<'ctx>, bool),
|
||||
rhs: (BasicValueEnum<'ctx>, bool),
|
||||
value_fn: ValueFn,
|
||||
) -> Result<NDArrayValue<'ctx>, String>
|
||||
where
|
||||
G: CodeGenerator,
|
||||
ValueFn: Fn(&mut G, &mut CodeGenContext<'ctx, '_>, (BasicValueEnum<'ctx>, BasicValueEnum<'ctx>)) -> Result<BasicValueEnum<'ctx>, String>,
|
||||
{
|
||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||
|
||||
let (lhs_val, lhs_scalar) = lhs;
|
||||
let (rhs_val, rhs_scalar) = rhs;
|
||||
|
||||
assert!(!(lhs_scalar && rhs_scalar),
|
||||
"One of the operands must be a ndarray instance: `{}`, `{}`",
|
||||
lhs_val.get_type(),
|
||||
rhs_val.get_type());
|
||||
|
||||
let ndarray = res.unwrap_or_else(|| {
|
||||
if lhs_scalar && rhs_scalar {
|
||||
let lhs_val = NDArrayValue::from_ptr_val(lhs_val.into_pointer_value(), llvm_usize, None);
|
||||
let rhs_val = NDArrayValue::from_ptr_val(rhs_val.into_pointer_value(), llvm_usize, None);
|
||||
|
||||
let ndarray_dims = call_ndarray_calc_broadcast(generator, ctx, lhs_val, rhs_val);
|
||||
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&ndarray_dims,
|
||||
|generator, ctx, v| {
|
||||
Ok(v.size(ctx, generator))
|
||||
},
|
||||
|generator, ctx, v, idx| {
|
||||
unsafe {
|
||||
Ok(v.get_typed_unchecked(ctx, generator, idx, None))
|
||||
}
|
||||
},
|
||||
).unwrap()
|
||||
} else {
|
||||
let ndarray = NDArrayValue::from_ptr_val(
|
||||
if lhs_scalar { rhs_val } else { lhs_val }.into_pointer_value(),
|
||||
llvm_usize,
|
||||
None,
|
||||
);
|
||||
|
||||
create_ndarray_dyn_shape(
|
||||
generator,
|
||||
ctx,
|
||||
elem_ty,
|
||||
&ndarray,
|
||||
|_, ctx, v| {
|
||||
Ok(v.load_ndims(ctx))
|
||||
},
|
||||
|generator, ctx, v, idx| {
|
||||
unsafe {
|
||||
Ok(v.dim_sizes().get_typed_unchecked(ctx, generator, idx, None))
|
||||
}
|
||||
},
|
||||
).unwrap()
|
||||
}
|
||||
});
|
||||
|
||||
ndarray_broadcast_fill(
|
||||
generator,
|
||||
ctx,
|
||||
ndarray,
|
||||
lhs,
|
||||
rhs,
|
||||
|generator, ctx, elems| {
|
||||
value_fn(generator, ctx, elems)
|
||||
},
|
||||
)?;
|
||||
|
||||
Ok(ndarray)
|
||||
}
|
||||
|
||||
/// Generates LLVM IR for `ndarray.empty`.
|
||||
pub fn gen_ndarray_empty<'ctx>(
|
||||
context: &mut CodeGenContext<'ctx, '_>,
|
||||
|
|
|
@ -299,6 +299,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
Some("N".into()),
|
||||
None,
|
||||
);
|
||||
let size_t = primitives.0.usize();
|
||||
|
||||
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
|
||||
let exception_fields = vec![
|
||||
("__name__".into(), int32, true),
|
||||
|
@ -345,8 +347,27 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
.nth(1)
|
||||
.map(|(var_id, ty)| (*ty, *var_id))
|
||||
.unwrap();
|
||||
let ndarray_usized_ndims_tvar = primitives.1.get_fresh_const_generic_var(
|
||||
size_t,
|
||||
Some("ndarray_ndims".into()),
|
||||
None,
|
||||
);
|
||||
let ndarray_copy_ty = *ndarray_fields.get(&"copy".into()).unwrap();
|
||||
let ndarray_fill_ty = *ndarray_fields.get(&"fill".into()).unwrap();
|
||||
let ndarray_add_ty = *ndarray_fields.get(&"__add__".into()).unwrap();
|
||||
let ndarray_sub_ty = *ndarray_fields.get(&"__sub__".into()).unwrap();
|
||||
let ndarray_mul_ty = *ndarray_fields.get(&"__mul__".into()).unwrap();
|
||||
let ndarray_truediv_ty = *ndarray_fields.get(&"__truediv__".into()).unwrap();
|
||||
let ndarray_floordiv_ty = *ndarray_fields.get(&"__floordiv__".into()).unwrap();
|
||||
let ndarray_mod_ty = *ndarray_fields.get(&"__mod__".into()).unwrap();
|
||||
let ndarray_pow_ty = *ndarray_fields.get(&"__pow__".into()).unwrap();
|
||||
let ndarray_iadd_ty = *ndarray_fields.get(&"__iadd__".into()).unwrap();
|
||||
let ndarray_isub_ty = *ndarray_fields.get(&"__isub__".into()).unwrap();
|
||||
let ndarray_imul_ty = *ndarray_fields.get(&"__imul__".into()).unwrap();
|
||||
let ndarray_itruediv_ty = *ndarray_fields.get(&"__itruediv__".into()).unwrap();
|
||||
let ndarray_ifloordiv_ty = *ndarray_fields.get(&"__ifloordiv__".into()).unwrap();
|
||||
let ndarray_imod_ty = *ndarray_fields.get(&"__imod__".into()).unwrap();
|
||||
let ndarray_ipow_ty = *ndarray_fields.get(&"__ipow__".into()).unwrap();
|
||||
|
||||
let top_level_def_list = vec![
|
||||
Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def(
|
||||
|
@ -524,6 +545,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
methods: vec![
|
||||
("copy".into(), ndarray_copy_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 1)),
|
||||
("fill".into(), ndarray_fill_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 2)),
|
||||
("__add__".into(), ndarray_add_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 3)),
|
||||
("__sub__".into(), ndarray_sub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 4)),
|
||||
("__mul__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 5)),
|
||||
("__truediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 6)),
|
||||
("__floordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 7)),
|
||||
("__mod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 8)),
|
||||
("__pow__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 9)),
|
||||
("__iadd__".into(), ndarray_iadd_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 10)),
|
||||
("__isub__".into(), ndarray_isub_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 11)),
|
||||
("__imul__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 12)),
|
||||
("__itruediv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 13)),
|
||||
("__ifloordiv__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 14)),
|
||||
("__imod__".into(), ndarray_mul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 15)),
|
||||
("__ipow__".into(), ndarray_imul_ty.0, DefinitionId(PRIMITIVE_DEF_IDS.ndarray.0 + 16)),
|
||||
],
|
||||
ancestors: Vec::default(),
|
||||
constructor: None,
|
||||
|
@ -562,6 +597,216 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
|
|||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__add__".into(),
|
||||
simple_name: "__add__".into(),
|
||||
signature: ndarray_add_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__sub__".into(),
|
||||
simple_name: "__sub__".into(),
|
||||
signature: ndarray_sub_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__mul__".into(),
|
||||
simple_name: "__mul__".into(),
|
||||
signature: ndarray_mul_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__truediv__".into(),
|
||||
simple_name: "__truediv__".into(),
|
||||
signature: ndarray_truediv_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__floordiv__".into(),
|
||||
simple_name: "__floordiv__".into(),
|
||||
signature: ndarray_floordiv_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__mod__".into(),
|
||||
simple_name: "__mod__".into(),
|
||||
signature: ndarray_mod_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__pow__".into(),
|
||||
simple_name: "__pow__".into(),
|
||||
signature: ndarray_pow_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__iadd__".into(),
|
||||
simple_name: "__iadd__".into(),
|
||||
signature: ndarray_iadd_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id, ndarray_usized_ndims_tvar.1],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__isub__".into(),
|
||||
simple_name: "__isub__".into(),
|
||||
signature: ndarray_isub_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__imul__".into(),
|
||||
simple_name: "__imul__".into(),
|
||||
signature: ndarray_imul_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__itruediv__".into(),
|
||||
simple_name: "__itruediv__".into(),
|
||||
signature: ndarray_itruediv_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__ifloordiv__".into(),
|
||||
simple_name: "__ifloordiv__".into(),
|
||||
signature: ndarray_ifloordiv_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__imod__".into(),
|
||||
simple_name: "__imod__".into(),
|
||||
signature: ndarray_imod_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "ndarray.__ipow__".into(),
|
||||
simple_name: "__ipow__".into(),
|
||||
signature: ndarray_ipow_ty.0,
|
||||
var_id: vec![ndarray_dtype_var_id, ndarray_ndims_var_id],
|
||||
instance_to_symbol: HashMap::default(),
|
||||
instance_to_stmt: HashMap::default(),
|
||||
resolver: None,
|
||||
codegen_callback: Some(Arc::new(GenCall::new(Box::new(
|
||||
|_, _, _, _, _| {
|
||||
unreachable!("handled in gen_expr")
|
||||
},
|
||||
)))),
|
||||
loc: None,
|
||||
})),
|
||||
Arc::new(RwLock::new(TopLevelDef::Function {
|
||||
name: "int32".into(),
|
||||
simple_name: "int32".into(),
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use std::convert::TryInto;
|
||||
|
||||
use crate::symbol_resolver::SymbolValue;
|
||||
use crate::toplevel::numpy::subst_ndarray_tvars;
|
||||
use crate::typecheck::typedef::{Mapping, VarMap};
|
||||
use nac3parser::ast::{Constant, Location};
|
||||
|
||||
|
@ -231,11 +232,57 @@ impl TopLevelComposer {
|
|||
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||
]),
|
||||
}));
|
||||
let ndarray_binop_fun_other_ty = unifier.get_fresh_var(None, None);
|
||||
let ndarray_binop_fun_ret_ty = unifier.get_fresh_var(None, None);
|
||||
let ndarray_binop_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
name: "other".into(),
|
||||
ty: ndarray_binop_fun_other_ty.0,
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret: ndarray_binop_fun_ret_ty.0,
|
||||
vars: VarMap::from([
|
||||
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||
]),
|
||||
}));
|
||||
let ndarray_truediv_fun_other_ty = unifier.get_fresh_var(None, None);
|
||||
let ndarray_truediv_fun_ret_ty = unifier.get_fresh_var(None, None);
|
||||
let ndarray_truediv_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||
args: vec![
|
||||
FuncArg {
|
||||
name: "other".into(),
|
||||
ty: ndarray_truediv_fun_other_ty.0,
|
||||
default_value: None,
|
||||
},
|
||||
],
|
||||
ret: ndarray_truediv_fun_ret_ty.0,
|
||||
vars: VarMap::from([
|
||||
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||
]),
|
||||
}));
|
||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||
fields: Mapping::from([
|
||||
("copy".into(), (ndarray_copy_fun_ty, true)),
|
||||
("fill".into(), (ndarray_fill_fun_ty, true)),
|
||||
("__add__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__sub__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__mul__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__truediv__".into(), (ndarray_truediv_fun_ty, true)),
|
||||
("__floordiv__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__mod__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__pow__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__iadd__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__isub__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__imul__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__itruediv__".into(), (ndarray_truediv_fun_ty, true)),
|
||||
("__ifloordiv__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__imod__".into(), (ndarray_binop_fun_ty, true)),
|
||||
("__ipow__".into(), (ndarray_binop_fun_ty, true)),
|
||||
]),
|
||||
params: VarMap::from([
|
||||
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||
|
@ -243,7 +290,16 @@ impl TopLevelComposer {
|
|||
]),
|
||||
});
|
||||
|
||||
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t_ty, Some("ndarray_ndims".into()), None);
|
||||
let ndarray_unsized = subst_ndarray_tvars(&mut unifier, ndarray, Some(ndarray_usized_ndims_tvar.0), None);
|
||||
|
||||
unifier.unify(ndarray_copy_fun_ret_ty.0, ndarray).unwrap();
|
||||
unifier.unify(ndarray_binop_fun_other_ty.0, ndarray_unsized).unwrap();
|
||||
unifier.unify(ndarray_binop_fun_ret_ty.0, ndarray).unwrap();
|
||||
|
||||
let ndarray_float = subst_ndarray_tvars(&mut unifier, ndarray, Some(float), None);
|
||||
unifier.unify(ndarray_truediv_fun_other_ty.0, ndarray).unwrap();
|
||||
unifier.unify(ndarray_truediv_fun_ret_ty.0, ndarray_float).unwrap();
|
||||
|
||||
let primitives = PrimitiveStore {
|
||||
int32,
|
||||
|
|
|
@ -5,7 +5,7 @@ expression: res_vec
|
|||
[
|
||||
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
|
||||
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [30]\n}\n",
|
||||
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [124]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -7,7 +7,7 @@ expression: res_vec
|
|||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar19]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar19\"]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B[typevar113]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar113\"]\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -5,8 +5,8 @@ expression: res_vec
|
|||
[
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [32]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [37]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [126]\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [131]\n}\n",
|
||||
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
|
|
|
@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
|
|||
expression: res_vec
|
||||
---
|
||||
[
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar18, typevar19]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar18\", \"typevar19\"]\n}\n",
|
||||
"Class {\nname: \"A\",\nancestors: [\"A[typevar112, typevar113]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar112\", \"typevar113\"]\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",
|
||||
|
|
|
@ -6,12 +6,12 @@ expression: res_vec
|
|||
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [38]\n}\n",
|
||||
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [132]\n}\n",
|
||||
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
|
||||
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [46]\n}\n",
|
||||
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [140]\n}\n",
|
||||
]
|
||||
|
|
|
@ -453,8 +453,8 @@ pub fn typeof_binop(
|
|||
}
|
||||
|
||||
Operator::LShift
|
||||
| Operator::RShift
|
||||
| Operator::BitOr
|
||||
| Operator::RShift => lhs,
|
||||
Operator::BitOr
|
||||
| Operator::BitXor
|
||||
| Operator::BitAnd => {
|
||||
if unifier.unioned(lhs, rhs) {
|
||||
|
@ -474,18 +474,21 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
bool: bool_t,
|
||||
uint32: uint32_t,
|
||||
uint64: uint64_t,
|
||||
ndarray: ndarray_t,
|
||||
..
|
||||
} = *store;
|
||||
let size_t = store.usize();
|
||||
|
||||
/* int ======== */
|
||||
for t in [int32_t, int64_t, uint32_t, uint64_t] {
|
||||
impl_basic_arithmetic(unifier, store, t, &[t], Some(t));
|
||||
impl_pow(unifier, store, t, &[t], Some(t));
|
||||
let ndarray_int_t = make_ndarray_ty(unifier, store, Some(t), None);
|
||||
impl_basic_arithmetic(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_pow(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_bitwise_arithmetic(unifier, store, t);
|
||||
impl_bitwise_shift(unifier, store, t);
|
||||
impl_div(unifier, store, t, &[t], Some(float_t));
|
||||
impl_floordiv(unifier, store, t, &[t], Some(t));
|
||||
impl_mod(unifier, store, t, &[t], Some(t));
|
||||
impl_div(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_floordiv(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_mod(unifier, store, t, &[t, ndarray_int_t], None);
|
||||
impl_invert(unifier, store, t, Some(t));
|
||||
impl_not(unifier, store, t, Some(bool_t));
|
||||
impl_comparison(unifier, store, t, &[t], Some(bool_t));
|
||||
|
@ -496,11 +499,13 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
}
|
||||
|
||||
/* float ======== */
|
||||
impl_basic_arithmetic(unifier, store, float_t, &[float_t], Some(float_t));
|
||||
impl_pow(unifier, store, float_t, &[int32_t, float_t], Some(float_t));
|
||||
impl_div(unifier, store, float_t, &[float_t], Some(float_t));
|
||||
impl_floordiv(unifier, store, float_t, &[float_t], Some(float_t));
|
||||
impl_mod(unifier, store, float_t, &[float_t], Some(float_t));
|
||||
let ndarray_float_t = make_ndarray_ty(unifier, store, Some(float_t), None);
|
||||
let ndarray_int32_t = make_ndarray_ty(unifier, store, Some(int32_t), None);
|
||||
impl_basic_arithmetic(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_pow(unifier, store, float_t, &[int32_t, float_t, ndarray_int32_t, ndarray_float_t], None);
|
||||
impl_div(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_floordiv(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_mod(unifier, store, float_t, &[float_t, ndarray_float_t], None);
|
||||
impl_sign(unifier, store, float_t, Some(float_t));
|
||||
impl_not(unifier, store, float_t, Some(bool_t));
|
||||
impl_comparison(unifier, store, float_t, &[float_t], Some(bool_t));
|
||||
|
@ -509,4 +514,15 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie
|
|||
/* bool ======== */
|
||||
impl_not(unifier, store, bool_t, Some(bool_t));
|
||||
impl_eq(unifier, store, bool_t, &[bool_t], Some(bool_t));
|
||||
|
||||
/* ndarray ===== */
|
||||
let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None);
|
||||
let ndarray_unsized_t = make_ndarray_ty(unifier, store, None, Some(ndarray_usized_ndims_tvar.0));
|
||||
let (ndarray_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_t);
|
||||
let (ndarray_unsized_dtype_t, _) = unpack_ndarray_var_tys(unifier, ndarray_unsized_t);
|
||||
impl_basic_arithmetic(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_pow(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_div(unifier, store, ndarray_t, &[ndarray_t, ndarray_dtype_t], None);
|
||||
impl_floordiv(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
impl_mod(unifier, store, ndarray_t, &[ndarray_unsized_t, ndarray_unsized_dtype_t], None);
|
||||
}
|
||||
|
|
|
@ -1203,8 +1203,11 @@ impl<'a> Inferencer<'a> {
|
|||
right: &ast::Expr<Option<Type>>,
|
||||
is_aug_assign: bool,
|
||||
) -> InferenceResult {
|
||||
let left_ty = left.custom.unwrap();
|
||||
let right_ty = right.custom.unwrap();
|
||||
|
||||
let method = if let TypeEnum::TObj { fields, .. } =
|
||||
self.unifier.get_ty_immutable(left.custom.unwrap()).as_ref()
|
||||
self.unifier.get_ty_immutable(left_ty).as_ref()
|
||||
{
|
||||
let (binop_name, binop_assign_name) = (
|
||||
binop_name(op).into(),
|
||||
|
@ -1219,12 +1222,26 @@ impl<'a> Inferencer<'a> {
|
|||
} else {
|
||||
binop_name(op).into()
|
||||
};
|
||||
|
||||
let ret = if is_aug_assign {
|
||||
// The type of augmented assignment operator should never change
|
||||
Some(left_ty)
|
||||
} else {
|
||||
typeof_binop(
|
||||
self.unifier,
|
||||
self.primitives,
|
||||
op,
|
||||
left_ty,
|
||||
right_ty,
|
||||
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?
|
||||
};
|
||||
|
||||
self.build_method_call(
|
||||
location,
|
||||
method,
|
||||
left.custom.unwrap(),
|
||||
vec![right.custom.unwrap()],
|
||||
None,
|
||||
left_ty,
|
||||
vec![right_ty],
|
||||
ret,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -135,10 +135,15 @@ impl TestEnvironment {
|
|||
fields: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
});
|
||||
let ndarray_dtype_tvar = unifier.get_fresh_var(Some("ndarray_dtype".into()), None);
|
||||
let ndarray_ndims_tvar = unifier.get_fresh_const_generic_var(uint64, Some("ndarray_ndims".into()), None);
|
||||
let ndarray = unifier.add_ty(TypeEnum::TObj {
|
||||
obj_id: PRIMITIVE_DEF_IDS.ndarray,
|
||||
fields: HashMap::new(),
|
||||
params: VarMap::new(),
|
||||
params: VarMap::from([
|
||||
(ndarray_dtype_tvar.1, ndarray_dtype_tvar.0),
|
||||
(ndarray_ndims_tvar.1, ndarray_ndims_tvar.0),
|
||||
]),
|
||||
});
|
||||
let primitives = PrimitiveStore {
|
||||
int32,
|
||||
|
|
|
@ -774,12 +774,8 @@ impl Unifier {
|
|||
|
||||
// If the types don't match, try to implicitly promote integers
|
||||
if !self.unioned(ty, value_ty) {
|
||||
let num_val = match *value {
|
||||
SymbolValue::I32(v) => v as i128,
|
||||
SymbolValue::I64(v) => v as i128,
|
||||
SymbolValue::U32(v) => v as i128,
|
||||
SymbolValue::U64(v) => v as i128,
|
||||
_ => return Self::incompatible_types(a, b),
|
||||
let Ok(num_val) = i128::try_from(value.clone()) else {
|
||||
return Self::incompatible_types(a, b)
|
||||
};
|
||||
|
||||
let can_convert = if self.unioned(ty, primitives.int32) {
|
||||
|
|
|
@ -6,6 +6,19 @@ def output_int32(x: int32):
|
|||
def output_float64(x: float):
|
||||
...
|
||||
|
||||
def output_ndarray_int32_1(n: ndarray[int32, Literal[1]]):
|
||||
for i in range(len(n)):
|
||||
output_int32(n[i])
|
||||
|
||||
def output_ndarray_float_1(n: ndarray[float, Literal[1]]):
|
||||
for i in range(len(n)):
|
||||
output_float64(n[i])
|
||||
|
||||
def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
|
||||
for r in range(len(n)):
|
||||
for c in range(len(n[r])):
|
||||
output_float64(n[r][c])
|
||||
|
||||
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
|
||||
pass
|
||||
|
||||
|
@ -19,53 +32,381 @@ def test_ndarray_empty():
|
|||
|
||||
def test_ndarray_zeros():
|
||||
n: ndarray[float, 1] = np_zeros([1])
|
||||
output_float64(n[0])
|
||||
output_ndarray_float_1(n)
|
||||
|
||||
def test_ndarray_ones():
|
||||
n: ndarray[float, 1] = np_ones([1])
|
||||
output_float64(n[0])
|
||||
output_ndarray_float_1(n)
|
||||
|
||||
def test_ndarray_full():
|
||||
n_float: ndarray[float, 1] = np_full([1], 2.0)
|
||||
output_float64(n_float[0])
|
||||
output_ndarray_float_1(n_float)
|
||||
n_i32: ndarray[int32, 1] = np_full([1], 2)
|
||||
output_int32(n_i32[0])
|
||||
output_ndarray_int32_1(n_i32)
|
||||
|
||||
def test_ndarray_eye():
|
||||
n: ndarray[float, 2] = np_eye(2)
|
||||
n0: ndarray[float, 1] = n[0]
|
||||
v: float = n0[0]
|
||||
output_float64(v)
|
||||
output_ndarray_float_2(n)
|
||||
|
||||
def test_ndarray_identity():
|
||||
n: ndarray[float, 2] = np_identity(2)
|
||||
output_float64(n[0][0])
|
||||
output_float64(n[0][1])
|
||||
output_float64(n[1][0])
|
||||
output_float64(n[1][1])
|
||||
output_ndarray_float_2(n)
|
||||
|
||||
def test_ndarray_fill():
|
||||
n: ndarray[float, 2] = np_empty([2, 2])
|
||||
n.fill(1.0)
|
||||
output_float64(n[0][0])
|
||||
output_float64(n[0][1])
|
||||
output_float64(n[1][0])
|
||||
output_float64(n[1][1])
|
||||
output_ndarray_float_2(n)
|
||||
|
||||
def test_ndarray_copy():
|
||||
x: ndarray[float, 2] = np_identity(2)
|
||||
y = x.copy()
|
||||
x.fill(0.0)
|
||||
|
||||
output_float64(x[0][0])
|
||||
output_float64(x[0][1])
|
||||
output_float64(x[1][0])
|
||||
output_float64(x[1][1])
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
output_float64(y[0][0])
|
||||
output_float64(y[0][1])
|
||||
output_float64(y[1][0])
|
||||
output_float64(y[1][1])
|
||||
def test_ndarray_add():
|
||||
x = np_identity(2)
|
||||
y = x + np_ones([2, 2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_add_broadcast():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x + np_ones([2])
|
||||
y = x + np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_add_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = 1.0 + x
|
||||
y = 1.0 + x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_add_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x + 1.0
|
||||
y = x + 1.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_iadd():
|
||||
x = np_identity(2)
|
||||
x += np_ones([2, 2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_iadd_broadcast():
|
||||
x = np_identity(2)
|
||||
x += np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_iadd_broadcast_scalar():
|
||||
x = np_identity(2)
|
||||
x += 1.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_sub():
|
||||
x = np_ones([2, 2])
|
||||
y = x - np_identity(2)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_sub_broadcast():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x - np_ones([2])
|
||||
y = x - np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_sub_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = 1.0 - x
|
||||
y = 1.0 - x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_sub_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x - 1
|
||||
y = x - 1.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_isub():
|
||||
x = np_ones([2, 2])
|
||||
x -= np_identity(2)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_isub_broadcast():
|
||||
x = np_identity(2)
|
||||
x -= np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_isub_broadcast_scalar():
|
||||
x = np_identity(2)
|
||||
x -= 1.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_mul():
|
||||
x = np_ones([2, 2])
|
||||
y = x * np_identity(2)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_mul_broadcast():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x * np_ones([2])
|
||||
y = x * np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_mul_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = 2.0 * x
|
||||
y = 2.0 * x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_mul_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x * 2.0
|
||||
y = x * 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_imul():
|
||||
x = np_ones([2, 2])
|
||||
x *= np_identity(2)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_imul_broadcast():
|
||||
x = np_identity(2)
|
||||
x *= np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_imul_broadcast_scalar():
|
||||
x = np_identity(2)
|
||||
x *= 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_truediv():
|
||||
x = np_identity(2)
|
||||
y = x / np_ones([2, 2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_truediv_broadcast():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x / np_ones([2])
|
||||
y = x / np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_truediv_broadcast_lhs_scalar():
|
||||
x = np_ones([2, 2])
|
||||
# y: ndarray[float, 2] = 2.0 / x
|
||||
y = 2.0 / x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_truediv_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x / 2.0
|
||||
y = x / 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_itruediv():
|
||||
x = np_identity(2)
|
||||
x /= np_ones([2, 2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_itruediv_broadcast():
|
||||
x = np_identity(2)
|
||||
x /= np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_itruediv_broadcast_scalar():
|
||||
x = np_identity(2)
|
||||
x /= 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_floordiv():
|
||||
x = np_identity(2)
|
||||
y = x // np_ones([2, 2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_floordiv_broadcast():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x // np_ones([2])
|
||||
y = x // np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_floordiv_broadcast_lhs_scalar():
|
||||
x = np_ones([2, 2])
|
||||
# y: ndarray[float, 2] = 2.0 // x
|
||||
y = 2.0 // x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_floordiv_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x // 2.0
|
||||
y = x // 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_ifloordiv():
|
||||
x = np_identity(2)
|
||||
x //= np_ones([2, 2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_ifloordiv_broadcast():
|
||||
x = np_identity(2)
|
||||
x //= np_ones([2])
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_ifloordiv_broadcast_scalar():
|
||||
x = np_identity(2)
|
||||
x //= 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_mod():
|
||||
x = np_identity(2)
|
||||
y = x % np_full([2, 2], 2.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_mod_broadcast():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x % np_ones([2])
|
||||
y = x % np_full([2], 2.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_mod_broadcast_lhs_scalar():
|
||||
x = np_ones([2, 2])
|
||||
# y: ndarray[float, 2] = 2.0 % x
|
||||
y = 2.0 % x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_mod_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x % 2.0
|
||||
y = x % 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_imod():
|
||||
x = np_identity(2)
|
||||
x %= np_full([2, 2], 2.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_imod_broadcast():
|
||||
x = np_identity(2)
|
||||
x %= np_full([2], 2.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_imod_broadcast_scalar():
|
||||
x = np_identity(2)
|
||||
x %= 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_pow():
|
||||
x = np_identity(2)
|
||||
y = x ** np_full([2, 2], 2.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_pow_broadcast():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x ** np_full([2], 2.0)
|
||||
y = x ** np_full([2], 2.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_pow_broadcast_lhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = 2.0 ** x
|
||||
y = 2.0 ** x
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_pow_broadcast_rhs_scalar():
|
||||
x = np_identity(2)
|
||||
# y: ndarray[float, 2] = x % 2.0
|
||||
y = x ** 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
output_ndarray_float_2(y)
|
||||
|
||||
def test_ndarray_ipow():
|
||||
x = np_identity(2)
|
||||
x **= np_full([2, 2], 2.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_ipow_broadcast():
|
||||
x = np_identity(2)
|
||||
x **= np_full([2], 2.0)
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def test_ndarray_ipow_broadcast_scalar():
|
||||
x = np_identity(2)
|
||||
x **= 2.0
|
||||
|
||||
output_ndarray_float_2(x)
|
||||
|
||||
def run() -> int32:
|
||||
test_ndarray_ctor()
|
||||
|
@ -77,5 +418,54 @@ def run() -> int32:
|
|||
test_ndarray_identity()
|
||||
test_ndarray_fill()
|
||||
test_ndarray_copy()
|
||||
test_ndarray_add()
|
||||
test_ndarray_add_broadcast()
|
||||
test_ndarray_add_broadcast_lhs_scalar()
|
||||
test_ndarray_add_broadcast_rhs_scalar()
|
||||
test_ndarray_iadd()
|
||||
test_ndarray_iadd_broadcast()
|
||||
test_ndarray_iadd_broadcast_scalar()
|
||||
test_ndarray_sub()
|
||||
test_ndarray_sub_broadcast()
|
||||
test_ndarray_sub_broadcast_lhs_scalar()
|
||||
test_ndarray_sub_broadcast_rhs_scalar()
|
||||
test_ndarray_isub()
|
||||
test_ndarray_isub_broadcast()
|
||||
test_ndarray_isub_broadcast_scalar()
|
||||
test_ndarray_mul()
|
||||
test_ndarray_mul_broadcast()
|
||||
test_ndarray_mul_broadcast_lhs_scalar()
|
||||
test_ndarray_mul_broadcast_rhs_scalar()
|
||||
test_ndarray_imul()
|
||||
test_ndarray_imul_broadcast()
|
||||
test_ndarray_imul_broadcast_scalar()
|
||||
test_ndarray_truediv()
|
||||
test_ndarray_truediv_broadcast()
|
||||
test_ndarray_truediv_broadcast_lhs_scalar()
|
||||
test_ndarray_truediv_broadcast_rhs_scalar()
|
||||
test_ndarray_itruediv()
|
||||
test_ndarray_itruediv_broadcast()
|
||||
test_ndarray_itruediv_broadcast_scalar()
|
||||
test_ndarray_floordiv()
|
||||
test_ndarray_floordiv_broadcast()
|
||||
test_ndarray_floordiv_broadcast_lhs_scalar()
|
||||
test_ndarray_floordiv_broadcast_rhs_scalar()
|
||||
test_ndarray_ifloordiv()
|
||||
test_ndarray_ifloordiv_broadcast()
|
||||
test_ndarray_ifloordiv_broadcast_scalar()
|
||||
test_ndarray_mod()
|
||||
test_ndarray_mod_broadcast()
|
||||
test_ndarray_mod_broadcast_lhs_scalar()
|
||||
test_ndarray_mod_broadcast_rhs_scalar()
|
||||
test_ndarray_imod()
|
||||
test_ndarray_imod_broadcast()
|
||||
test_ndarray_imod_broadcast_scalar()
|
||||
test_ndarray_pow()
|
||||
test_ndarray_pow_broadcast()
|
||||
test_ndarray_pow_broadcast_lhs_scalar()
|
||||
test_ndarray_pow_broadcast_rhs_scalar()
|
||||
test_ndarray_ipow()
|
||||
test_ndarray_ipow_broadcast()
|
||||
test_ndarray_ipow_broadcast_scalar()
|
||||
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue