core: implement np_dot using LLVM_IR
This commit is contained in:
parent
4a6845dac6
commit
54f883f0a5
@ -1865,34 +1865,6 @@ fn build_output_struct<'ctx>(
|
|||||||
out_ptr
|
out_ptr
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `np_dot` linalg function
|
|
||||||
pub fn call_np_dot<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &mut G,
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
x1: (Type, BasicValueEnum<'ctx>),
|
|
||||||
x2: (Type, BasicValueEnum<'ctx>),
|
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
|
||||||
const FN_NAME: &str = "np_dot";
|
|
||||||
let (x1_ty, x1) = x1;
|
|
||||||
let (x2_ty, x2) = x2;
|
|
||||||
|
|
||||||
if let (BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) = (x1, x2) {
|
|
||||||
let (n1_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x1_ty);
|
|
||||||
let n1_elem_ty = ctx.get_llvm_type(generator, n1_elem_ty);
|
|
||||||
let (n2_elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, x2_ty);
|
|
||||||
let n2_elem_ty = ctx.get_llvm_type(generator, n2_elem_ty);
|
|
||||||
|
|
||||||
let (BasicTypeEnum::FloatType(_), BasicTypeEnum::FloatType(_)) = (n1_elem_ty, n2_elem_ty)
|
|
||||||
else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty]);
|
|
||||||
};
|
|
||||||
|
|
||||||
Ok(extern_fns::call_np_dot(ctx, x1, x2, None).into())
|
|
||||||
} else {
|
|
||||||
unsupported_type(ctx, FN_NAME, &[x1_ty, x2_ty])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Invokes the `np_linalg_matmul` linalg function
|
/// Invokes the `np_linalg_matmul` linalg function
|
||||||
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
|
pub fn call_np_linalg_matmul<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
|
@ -188,33 +188,3 @@ generate_linalg_extern_fn!(call_np_linalg_pinv, "np_linalg_pinv", 2);
|
|||||||
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
generate_linalg_extern_fn!(call_sp_linalg_lu, "sp_linalg_lu", 3);
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
generate_linalg_extern_fn!(call_sp_linalg_schur, "sp_linalg_schur", 3);
|
||||||
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|
generate_linalg_extern_fn!(call_sp_linalg_hessenberg, "sp_linalg_hessenberg", 3);
|
||||||
|
|
||||||
/// Invokes the linalg `np_dot` function.
|
|
||||||
pub fn call_np_dot<'ctx>(
|
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
|
||||||
mat1: BasicValueEnum<'ctx>,
|
|
||||||
mat2: BasicValueEnum<'ctx>,
|
|
||||||
name: Option<&str>,
|
|
||||||
) -> FloatValue<'ctx> {
|
|
||||||
const FN_NAME: &str = "np_dot";
|
|
||||||
|
|
||||||
let extern_fn = ctx.module.get_function(FN_NAME).unwrap_or_else(|| {
|
|
||||||
let fn_type =
|
|
||||||
ctx.ctx.f64_type().fn_type(&[mat1.get_type().into(), mat2.get_type().into()], false);
|
|
||||||
let func = ctx.module.add_function(FN_NAME, fn_type, None);
|
|
||||||
for attr in ["mustprogress", "nofree", "nounwind", "willreturn", "writeonly"] {
|
|
||||||
func.add_attribute(
|
|
||||||
AttributeLoc::Function,
|
|
||||||
ctx.ctx.create_enum_attribute(Attribute::get_named_enum_kind_id(attr), 0),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
func
|
|
||||||
});
|
|
||||||
|
|
||||||
ctx.builder
|
|
||||||
.build_call(extern_fn, &[mat1.into(), mat2.into()], name.unwrap_or_default())
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_float_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap()
|
|
||||||
}
|
|
||||||
|
@ -26,12 +26,15 @@ use crate::{
|
|||||||
typedef::{FunSignature, Type, TypeEnum},
|
typedef::{FunSignature, Type, TypeEnum},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use inkwell::types::{AnyTypeEnum, BasicTypeEnum, PointerType};
|
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::BasicType,
|
types::BasicType,
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate, OptimizationLevel,
|
||||||
};
|
};
|
||||||
|
use inkwell::{
|
||||||
|
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
|
||||||
|
values::BasicValue,
|
||||||
|
};
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::{Operator, StrRef};
|
||||||
|
|
||||||
/// Creates an uninitialized `NDArray` instance.
|
/// Creates an uninitialized `NDArray` instance.
|
||||||
@ -2390,7 +2393,7 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
generator,
|
generator,
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
ctx.builder.build_int_compare(IntPredicate::EQ, out_sz, n_sz, "").unwrap(),
|
||||||
"0:ValueError",
|
"0:ValueError",
|
||||||
"cannot reshape array of size {} into provided shape of size {}",
|
"cannot reshape array of size {0} into provided shape of size {1}",
|
||||||
[Some(n_sz), Some(out_sz), None],
|
[Some(n_sz), Some(out_sz), None],
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
@ -2417,3 +2420,102 @@ pub fn ndarray_reshape<'ctx, G: CodeGenerator + ?Sized>(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generates LLVM IR for `ndarray.dot`.
|
||||||
|
/// Calculate inner product of two vectors or literals
|
||||||
|
/// For matrix multiplication use `np_matmul`
|
||||||
|
///
|
||||||
|
/// The input `NDArray` are flattened and treated as 1D
|
||||||
|
/// The operation is equivalent to np.dot(arr1.ravel(), arr2.ravel())
|
||||||
|
pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
x1: (Type, BasicValueEnum<'ctx>),
|
||||||
|
x2: (Type, BasicValueEnum<'ctx>),
|
||||||
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
|
const FN_NAME: &str = "ndarray_dot";
|
||||||
|
let (x1_ty, x1) = x1;
|
||||||
|
let (_, x2) = x2;
|
||||||
|
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
match (x1, x2) {
|
||||||
|
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
||||||
|
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
||||||
|
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
||||||
|
|
||||||
|
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
||||||
|
|
||||||
|
ctx.make_assert(
|
||||||
|
generator,
|
||||||
|
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
|
||||||
|
"0:ValueError",
|
||||||
|
"shapes ({0}), ({1}) not aligned",
|
||||||
|
[Some(n1_sz), Some(n2_sz), None],
|
||||||
|
ctx.current_loc,
|
||||||
|
);
|
||||||
|
|
||||||
|
let identity =
|
||||||
|
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
||||||
|
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
|
||||||
|
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(n1_sz, false),
|
||||||
|
|generator, ctx, _, idx| {
|
||||||
|
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
|
||||||
|
|
||||||
|
let product = match elem1 {
|
||||||
|
BasicValueEnum::IntValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_int_mul(e1, elem2.into_int_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
BasicValueEnum::FloatValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_float_mul(e1, elem2.into_float_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||||
|
let acc_val = match acc_val {
|
||||||
|
BasicValueEnum::IntValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_int_add(e1, product.into_int_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
BasicValueEnum::FloatValue(e1) => ctx
|
||||||
|
.builder
|
||||||
|
.build_float_add(e1, product.into_float_value(), "")
|
||||||
|
.unwrap()
|
||||||
|
.as_basic_value_enum(),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
ctx.builder.build_store(acc, acc_val).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
||||||
|
Ok(acc_val)
|
||||||
|
}
|
||||||
|
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
||||||
|
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||||
|
}
|
||||||
|
(BasicValueEnum::FloatValue(e1), BasicValueEnum::FloatValue(e2)) => {
|
||||||
|
Ok(ctx.builder.build_float_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||||
|
}
|
||||||
|
_ => unreachable!(
|
||||||
|
"{FN_NAME}() not supported for '{}'",
|
||||||
|
format!("'{}'", ctx.unifier.stringify(x1_ty))
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1965,7 +1965,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
self.unifier,
|
self.unifier,
|
||||||
&self.num_or_ndarray_var_map,
|
&self.num_or_ndarray_var_map,
|
||||||
prim.name(),
|
prim.name(),
|
||||||
self.primitives.float,
|
self.num_ty.ty,
|
||||||
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
|
&[(self.num_or_ndarray_ty.ty, "x1"), (self.num_or_ndarray_ty.ty, "x2")],
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let x1_ty = fun.0.args[0].ty;
|
let x1_ty = fun.0.args[0].ty;
|
||||||
@ -1973,12 +1973,7 @@ impl<'a> BuiltinBuilder<'a> {
|
|||||||
let x2_ty = fun.0.args[1].ty;
|
let x2_ty = fun.0.args[1].ty;
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
Ok(Some(builtin_fns::call_np_dot(
|
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(x1_ty, x1_val),
|
|
||||||
(x2_ty, x2_val),
|
|
||||||
)?))
|
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
@ -1130,6 +1130,44 @@ impl<'a> Inferencer<'a> {
|
|||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Special-case type inference for `np_dot(x1, x2)`.
//
// The branch is taken only for exactly two positional arguments: `args.remove(0)` panics on
// an empty vector, and any extra arguments would otherwise be dropped from the rewritten call.
// Calls with a different arity fall through to the code following this branch.
if id == &"np_dot".into() && args.len() == 2 {
    let arg0 = self.fold_expr(args.remove(0))?;
    let arg1 = self.fold_expr(args.remove(0))?;
    let arg0_ty = arg0.custom.unwrap();
    let arg1_ty = arg1.custom.unwrap();

    // The result type follows the first operand: the element dtype when it is an ndarray,
    // otherwise the operand's own type.
    let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|obj_id| obj_id == PrimDef::NDArray.id())
    {
        let (ndarray_dtype, _) = unpack_ndarray_var_tys(self.unifier, arg0_ty);

        ndarray_dtype
    } else {
        arg0_ty
    };

    // Function type attached to the callee name node: (x1, x2) -> ret.
    let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
        args: vec![
            FuncArg { name: "x1".into(), ty: arg0_ty, default_value: None },
            FuncArg { name: "x2".into(), ty: arg1_ty, default_value: None },
        ],
        ret,
        vars: VarMap::new(),
    }));

    return Ok(Some(Located {
        location,
        custom: Some(ret),
        node: ExprKind::Call {
            func: Box::new(Located {
                custom: Some(custom),
                location: func.location,
                node: ExprKind::Name { id: *id, ctx: *ctx },
            }),
            args: vec![arg0, arg1],
            keywords: vec![],
        },
    }));
}
|
||||||
|
|
||||||
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
if ["np_min", "np_max"].iter().any(|fun_id| id == &(*fun_id).into()) && args.len() == 1 {
|
||||||
let arg0 = self.fold_expr(args.remove(0))?;
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
let arg0_ty = arg0.custom.unwrap();
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
|
@ -34,38 +34,6 @@ impl InputMatrix {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # Safety
|
|
||||||
///
|
|
||||||
/// `mat1` and `mat2` should point to a valid 1DArray of `f64` floats in row-major order
|
|
||||||
#[no_mangle]
|
|
||||||
pub unsafe extern "C" fn np_dot(mat1: *mut InputMatrix, mat2: *mut InputMatrix) -> f64 {
|
|
||||||
let mat1 = mat1.as_mut().unwrap();
|
|
||||||
let mat2 = mat2.as_mut().unwrap();
|
|
||||||
|
|
||||||
if !(mat1.ndims == 1 && mat2.ndims == 1) {
|
|
||||||
let err_msg = format!(
|
|
||||||
"expected 1D Vector Input, but received {}D and {}D input",
|
|
||||||
mat1.ndims, mat2.ndims
|
|
||||||
);
|
|
||||||
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
let dim1 = (*mat1).get_dims();
|
|
||||||
let dim2 = (*mat2).get_dims();
|
|
||||||
|
|
||||||
if dim1[0] != dim2[0] {
|
|
||||||
let err_msg = format!("shapes ({},) and ({},) not aligned", dim1[0], dim2[0]);
|
|
||||||
report_error("ValueError", "np_dot", file!(), line!(), column!(), &err_msg);
|
|
||||||
}
|
|
||||||
let data_slice1 = unsafe { slice::from_raw_parts_mut(mat1.data, dim1[0]) };
|
|
||||||
let data_slice2 = unsafe { slice::from_raw_parts_mut(mat2.data, dim2[0]) };
|
|
||||||
|
|
||||||
let matrix1 = DMatrix::from_row_slice(dim1[0], 1, data_slice1);
|
|
||||||
let matrix2 = DMatrix::from_row_slice(dim2[0], 1, data_slice2);
|
|
||||||
|
|
||||||
matrix1.dot(&matrix2)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// # Safety
|
/// # Safety
|
||||||
///
|
///
|
||||||
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
|
/// `mat1` and `mat2` should point to a valid 2DArray of `f64` floats in row-major order
|
||||||
|
@ -1451,13 +1451,28 @@ def test_ndarray_reshape():
|
|||||||
output_ndarray_float_1(z)
|
output_ndarray_float_1(z)
|
||||||
|
|
||||||
def test_ndarray_dot():
|
def test_ndarray_dot():
|
||||||
x: ndarray[float, 1] = np_array([5.0, 1.0])
|
x1: ndarray[float, 1] = np_array([5.0, 1.0, 4.0, 2.0])
|
||||||
y: ndarray[float, 1] = np_array([5.0, 1.0])
|
y1: ndarray[float, 1] = np_array([5.0, 1.0, 6.0, 6.0])
|
||||||
z = np_dot(x, y)
|
z1 = np_dot(x1, y1)
|
||||||
|
|
||||||
output_ndarray_float_1(x)
|
x2: ndarray[int32, 1] = np_array([5, 1, 4, 2])
|
||||||
output_ndarray_float_1(y)
|
y2: ndarray[int32, 1] = np_array([5, 1, 6, 6])
|
||||||
output_float64(z)
|
z2 = np_dot(x2, y2)
|
||||||
|
|
||||||
|
x3: ndarray[bool, 1] = np_array([True, True, True, True])
|
||||||
|
y3: ndarray[bool, 1] = np_array([True, True, True, True])
|
||||||
|
z3 = np_dot(x3, y3)
|
||||||
|
|
||||||
|
z4 = np_dot(2, 3)
|
||||||
|
z5 = np_dot(2., 3.)
|
||||||
|
z6 = np_dot(True, False)
|
||||||
|
|
||||||
|
output_float64(z1)
|
||||||
|
output_int32(z2)
|
||||||
|
output_bool(z3)
|
||||||
|
output_int32(z4)
|
||||||
|
output_float64(z5)
|
||||||
|
output_bool(z6)
|
||||||
|
|
||||||
def test_ndarray_linalg_matmul():
|
def test_ndarray_linalg_matmul():
|
||||||
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
x: ndarray[float, 2] = np_array([[5.0, 1.0], [1.0, 4.0]])
|
||||||
|
Loading…
Reference in New Issue
Block a user