Compare commits
7 Commits
bc84d73c25
...
142e2e5cba
Author | SHA1 | Date |
---|---|---|
lyken | 142e2e5cba | |
lyken | 6811c79ae9 | |
lyken | 1850aaeb70 | |
lyken | c287ccc3a1 | |
lyken | cf56c322db | |
lyken | 533fc90960 | |
lyken | 95c54171cc |
File diff suppressed because it is too large
Load Diff
|
@ -1,8 +1,8 @@
|
||||||
use crate::{
|
use crate::{
|
||||||
codegen::{
|
codegen::{
|
||||||
classes::{
|
classes::{
|
||||||
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, NDArrayValue, ProxyType,
|
ArrayLikeIndexer, ArrayLikeValue, ListType, ListValue, ProxyType, ProxyValue,
|
||||||
ProxyValue, RangeValue, UntypedArrayLikeAccessor,
|
RangeValue, UntypedArrayLikeAccessor,
|
||||||
},
|
},
|
||||||
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
|
||||||
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
gen_in_range_check, get_llvm_abi_type, get_llvm_type, get_va_count_arg_name,
|
||||||
|
@ -11,7 +11,7 @@ use crate::{
|
||||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||||
call_memcpy_generic,
|
call_memcpy_generic,
|
||||||
},
|
},
|
||||||
need_sret, numpy,
|
need_sret,
|
||||||
object::ndarray::{NDArrayOut, ScalarOrNDArray},
|
object::ndarray::{NDArrayOut, ScalarOrNDArray},
|
||||||
stmt::{
|
stmt::{
|
||||||
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
gen_for_callback_incrementing, gen_if_callback, gen_if_else_expr_callback, gen_raise,
|
||||||
|
@ -20,7 +20,7 @@ use crate::{
|
||||||
CodeGenContext, CodeGenTask, CodeGenerator,
|
CodeGenContext, CodeGenTask, CodeGenerator,
|
||||||
},
|
},
|
||||||
symbol_resolver::{SymbolValue, ValueEnum},
|
symbol_resolver::{SymbolValue, ValueEnum},
|
||||||
toplevel::{helper::PrimDef, numpy::unpack_ndarray_var_tys, DefinitionId, TopLevelDef},
|
toplevel::{helper::PrimDef, DefinitionId, TopLevelDef},
|
||||||
typecheck::{
|
typecheck::{
|
||||||
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
magic_methods::{Binop, BinopVariant, HasOpInfo},
|
||||||
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
typedef::{FunSignature, FuncArg, Type, TypeEnum, TypeVarId, Unifier, VarMap},
|
||||||
|
@ -1773,14 +1773,12 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
_ => val.into(),
|
_ => val.into(),
|
||||||
}
|
}
|
||||||
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
} else if ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) {
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let ndarray = AnyObject { value: val, ty };
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let ndarray = NDArrayObject::from_object(generator, ctx, ndarray);
|
||||||
|
|
||||||
let val = NDArrayValue::from_ptr_val(val.into_pointer_value(), llvm_usize, None);
|
|
||||||
|
|
||||||
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
// ndarray uses `~` rather than `not` to perform elementwise inversion, convert it before
|
||||||
// passing it to the elementwise codegen function
|
// passing it to the elementwise codegen function
|
||||||
let op = if ndarray_dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
let op = if ndarray.dtype.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::Bool.id()) {
|
||||||
if op == ast::Unaryop::Invert {
|
if op == ast::Unaryop::Invert {
|
||||||
ast::Unaryop::Not
|
ast::Unaryop::Not
|
||||||
} else {
|
} else {
|
||||||
|
@ -1793,20 +1791,18 @@ pub fn gen_unaryop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
op
|
op
|
||||||
};
|
};
|
||||||
|
|
||||||
let res = numpy::ndarray_elementwise_unaryop_impl(
|
let mapped_ndarray = ndarray.map(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ndarray_dtype,
|
NDArrayOut::NewNDArray { dtype: ndarray.dtype },
|
||||||
None,
|
|generator, ctx, scalar| {
|
||||||
val,
|
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray.dtype), scalar))?
|
||||||
|generator, ctx, val| {
|
|
||||||
gen_unaryop_expr_with_values(generator, ctx, op, (&Some(ndarray_dtype), val))?
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(ctx, generator, ndarray_dtype)
|
.to_basic_value_enum(ctx, generator, ndarray.dtype)
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
res.as_base_value().into()
|
ValueEnum::Dynamic(mapped_ndarray.instance.value.as_basic_value_enum())
|
||||||
} else {
|
} else {
|
||||||
unimplemented!()
|
unimplemented!()
|
||||||
}))
|
}))
|
||||||
|
@ -1849,39 +1845,33 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
if left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
|| right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id())
|
||||||
{
|
{
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let (Some(left_ty), left) = left else { unreachable!() };
|
||||||
|
let (Some(right_ty), right) = comparators[0] else { unreachable!() };
|
||||||
let (Some(left_ty), lhs) = left else { unreachable!() };
|
|
||||||
let (Some(right_ty), rhs) = comparators[0] else { unreachable!() };
|
|
||||||
let op = ops[0];
|
let op = ops[0];
|
||||||
|
|
||||||
let is_ndarray1 =
|
let left = AnyObject { value: left, ty: left_ty };
|
||||||
left_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
let left =
|
||||||
let is_ndarray2 =
|
ScalarOrNDArray::split_object(generator, ctx, left).to_ndarray(generator, ctx);
|
||||||
right_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id());
|
|
||||||
|
|
||||||
return if is_ndarray1 && is_ndarray2 {
|
let right = AnyObject { value: right, ty: right_ty };
|
||||||
let (ndarray_dtype1, _) = unpack_ndarray_var_tys(&mut ctx.unifier, left_ty);
|
let right =
|
||||||
let (ndarray_dtype2, _) = unpack_ndarray_var_tys(&mut ctx.unifier, right_ty);
|
ScalarOrNDArray::split_object(generator, ctx, right).to_ndarray(generator, ctx);
|
||||||
|
|
||||||
assert!(ctx.unifier.unioned(ndarray_dtype1, ndarray_dtype2));
|
let result_ndarray = NDArrayObject::broadcast_starmap(
|
||||||
|
|
||||||
let left_val =
|
|
||||||
NDArrayValue::from_ptr_val(lhs.into_pointer_value(), llvm_usize, None);
|
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
ctx.primitives.bool,
|
&[left, right],
|
||||||
None,
|
NDArrayOut::NewNDArray { dtype: ctx.primitives.bool },
|
||||||
(left_val.as_base_value().into(), false),
|
|generator, ctx, scalars| {
|
||||||
(rhs, false),
|
let left_scalar = scalars[0];
|
||||||
|generator, ctx, (lhs, rhs)| {
|
let right_scalar = scalars[1];
|
||||||
|
|
||||||
let val = gen_cmpop_expr_with_values(
|
let val = gen_cmpop_expr_with_values(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
(Some(ndarray_dtype1), lhs),
|
(Some(left.dtype), left_scalar),
|
||||||
&[op],
|
&[op],
|
||||||
&[(Some(ndarray_dtype2), rhs)],
|
&[(Some(right.dtype), right_scalar)],
|
||||||
)?
|
)?
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.to_basic_value_enum(
|
.to_basic_value_enum(
|
||||||
|
@ -1894,40 +1884,7 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
return Ok(Some(result_ndarray.instance.value.into()));
|
||||||
} else {
|
|
||||||
let (ndarray_dtype, _) = unpack_ndarray_var_tys(
|
|
||||||
&mut ctx.unifier,
|
|
||||||
if is_ndarray1 { left_ty } else { right_ty },
|
|
||||||
);
|
|
||||||
let res = numpy::ndarray_elementwise_binop_impl(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
None,
|
|
||||||
(lhs, !is_ndarray1),
|
|
||||||
(rhs, !is_ndarray2),
|
|
||||||
|generator, ctx, (lhs, rhs)| {
|
|
||||||
let val = gen_cmpop_expr_with_values(
|
|
||||||
generator,
|
|
||||||
ctx,
|
|
||||||
(Some(ndarray_dtype), lhs),
|
|
||||||
&[op],
|
|
||||||
&[(Some(ndarray_dtype), rhs)],
|
|
||||||
)?
|
|
||||||
.unwrap()
|
|
||||||
.to_basic_value_enum(
|
|
||||||
ctx,
|
|
||||||
generator,
|
|
||||||
ctx.primitives.bool,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(generator.bool_to_i8(ctx, val.into_int_value()).into())
|
|
||||||
},
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Some(res.as_base_value().into()))
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,6 @@ use crate::{
|
||||||
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
ProxyType, ProxyValue, TypedArrayLikeAccessor, TypedArrayLikeAdapter,
|
||||||
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
TypedArrayLikeMutator, UntypedArrayLikeAccessor, UntypedArrayLikeMutator,
|
||||||
},
|
},
|
||||||
expr::gen_binop_expr_with_values,
|
|
||||||
irrt::{
|
irrt::{
|
||||||
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
|
calculate_len_for_slice_range, call_ndarray_calc_broadcast,
|
||||||
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
|
call_ndarray_calc_broadcast_index, call_ndarray_calc_nd_indices,
|
||||||
|
@ -15,9 +14,12 @@ use crate::{
|
||||||
model::*,
|
model::*,
|
||||||
object::{
|
object::{
|
||||||
any::AnyObject,
|
any::AnyObject,
|
||||||
ndarray::{shape_util::parse_numpy_int_sequence, NDArrayObject},
|
ndarray::{nditer::NDIterHandle, shape_util::parse_numpy_int_sequence, NDArrayObject},
|
||||||
|
},
|
||||||
|
stmt::{
|
||||||
|
gen_for_callback, gen_for_callback_incrementing, gen_for_range_callback,
|
||||||
|
gen_if_else_expr_callback,
|
||||||
},
|
},
|
||||||
stmt::{gen_for_callback_incrementing, gen_for_range_callback, gen_if_else_expr_callback},
|
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
symbol_resolver::ValueEnum,
|
symbol_resolver::ValueEnum,
|
||||||
|
@ -26,21 +28,18 @@ use crate::{
|
||||||
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
DefinitionId,
|
DefinitionId,
|
||||||
},
|
},
|
||||||
typecheck::{
|
typecheck::typedef::{FunSignature, Type},
|
||||||
magic_methods::Binop,
|
|
||||||
typedef::{FunSignature, Type},
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::BasicType,
|
types::BasicType,
|
||||||
values::{BasicValueEnum, IntValue, PointerValue},
|
values::{BasicValueEnum, IntValue, PointerValue},
|
||||||
AddressSpace, IntPredicate, OptimizationLevel,
|
AddressSpace, IntPredicate,
|
||||||
};
|
};
|
||||||
use inkwell::{
|
use inkwell::{
|
||||||
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
|
types::{AnyTypeEnum, BasicTypeEnum, PointerType},
|
||||||
values::BasicValue,
|
values::BasicValue,
|
||||||
};
|
};
|
||||||
use nac3parser::ast::{Operator, StrRef};
|
use nac3parser::ast::StrRef;
|
||||||
|
|
||||||
/// Creates an uninitialized `NDArray` instance.
|
/// Creates an uninitialized `NDArray` instance.
|
||||||
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
fn create_ndarray_uninitialized<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
@ -1708,77 +1707,88 @@ pub fn ndarray_dot<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
) -> Result<BasicValueEnum<'ctx>, String> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
const FN_NAME: &str = "ndarray_dot";
|
const FN_NAME: &str = "ndarray_dot";
|
||||||
let (x1_ty, x1) = x1;
|
let (x1_ty, x1) = x1;
|
||||||
let (_, x2) = x2;
|
let (x2_ty, x2) = x2;
|
||||||
|
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
|
||||||
|
|
||||||
match (x1, x2) {
|
match (x1, x2) {
|
||||||
(BasicValueEnum::PointerValue(n1), BasicValueEnum::PointerValue(n2)) => {
|
(BasicValueEnum::PointerValue(_), BasicValueEnum::PointerValue(_)) => {
|
||||||
let n1 = NDArrayValue::from_ptr_val(n1, llvm_usize, None);
|
let a = AnyObject { ty: x1_ty, value: x1 };
|
||||||
let n2 = NDArrayValue::from_ptr_val(n2, llvm_usize, None);
|
let b = AnyObject { ty: x2_ty, value: x2 };
|
||||||
|
|
||||||
let n1_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let a = NDArrayObject::from_object(generator, ctx, a);
|
||||||
let n2_sz = call_ndarray_calc_size(generator, ctx, &n1.dim_sizes(), (None, None));
|
let b = NDArrayObject::from_object(generator, ctx, b);
|
||||||
|
|
||||||
|
// TODO: General `np.dot()` https://numpy.org/doc/stable/reference/generated/numpy.dot.html.
|
||||||
|
assert_eq!(a.ndims, 1);
|
||||||
|
assert_eq!(b.ndims, 1);
|
||||||
|
let common_dtype = a.dtype;
|
||||||
|
|
||||||
|
// Check shapes.
|
||||||
|
let a_size = a.size(generator, ctx);
|
||||||
|
let b_size = b.size(generator, ctx);
|
||||||
|
let same_shape = a_size.compare(ctx, IntPredicate::EQ, b_size);
|
||||||
ctx.make_assert(
|
ctx.make_assert(
|
||||||
generator,
|
generator,
|
||||||
ctx.builder.build_int_compare(IntPredicate::EQ, n1_sz, n2_sz, "").unwrap(),
|
same_shape.value,
|
||||||
"0:ValueError",
|
"0:ValueError",
|
||||||
"shapes ({0}), ({1}) not aligned",
|
"shapes ({0},) and ({1},) not aligned: {0} (dim 0) != {1} (dim 1)",
|
||||||
[Some(n1_sz), Some(n2_sz), None],
|
[Some(a_size.value), Some(b_size.value), None],
|
||||||
ctx.current_loc,
|
ctx.current_loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
let identity =
|
let dtype_llvm = ctx.get_llvm_type(generator, common_dtype);
|
||||||
unsafe { n1.data().get_unchecked(ctx, generator, &llvm_usize.const_zero(), None) };
|
|
||||||
let acc = ctx.builder.build_alloca(identity.get_type(), "").unwrap();
|
|
||||||
ctx.builder.build_store(acc, identity.get_type().const_zero()).unwrap();
|
|
||||||
|
|
||||||
gen_for_callback_incrementing(
|
let result = ctx.builder.build_alloca(dtype_llvm, "np_dot_result").unwrap();
|
||||||
|
ctx.builder.build_store(result, dtype_llvm.const_zero()).unwrap();
|
||||||
|
|
||||||
|
// Do dot product.
|
||||||
|
gen_for_callback(
|
||||||
generator,
|
generator,
|
||||||
ctx,
|
ctx,
|
||||||
None,
|
Some("np_dot"),
|
||||||
llvm_usize.const_zero(),
|
|generator, ctx| {
|
||||||
(n1_sz, false),
|
let a_iter = NDIterHandle::new(generator, ctx, a);
|
||||||
|generator, ctx, _, idx| {
|
let b_iter = NDIterHandle::new(generator, ctx, b);
|
||||||
let elem1 = unsafe { n1.data().get_unchecked(ctx, generator, &idx, None) };
|
Ok((a_iter, b_iter))
|
||||||
let elem2 = unsafe { n2.data().get_unchecked(ctx, generator, &idx, None) };
|
},
|
||||||
|
|generator, ctx, (a_iter, _b_iter)| {
|
||||||
|
// Only a_iter drives the condition, b_iter should have the same status.
|
||||||
|
Ok(a_iter.has_next(generator, ctx).value)
|
||||||
|
},
|
||||||
|
|generator, ctx, _hooks, (a_iter, b_iter)| {
|
||||||
|
let a_scalar = a_iter.get_scalar(generator, ctx).value;
|
||||||
|
let b_scalar = b_iter.get_scalar(generator, ctx).value;
|
||||||
|
|
||||||
let product = match elem1 {
|
let old_result = ctx.builder.build_load(result, "").unwrap();
|
||||||
BasicValueEnum::IntValue(e1) => ctx
|
let new_result: BasicValueEnum<'ctx> = match old_result {
|
||||||
.builder
|
BasicValueEnum::IntValue(old_result) => {
|
||||||
.build_int_mul(e1, elem2.into_int_value(), "")
|
let a_scalar = a_scalar.into_int_value();
|
||||||
.unwrap()
|
let b_scalar = b_scalar.into_int_value();
|
||||||
.as_basic_value_enum(),
|
let x = ctx.builder.build_int_mul(a_scalar, b_scalar, "").unwrap();
|
||||||
BasicValueEnum::FloatValue(e1) => ctx
|
ctx.builder.build_int_add(old_result, x, "").unwrap().into()
|
||||||
.builder
|
}
|
||||||
.build_float_mul(e1, elem2.into_float_value(), "")
|
BasicValueEnum::FloatValue(old_result) => {
|
||||||
.unwrap()
|
let a_scalar = a_scalar.into_float_value();
|
||||||
.as_basic_value_enum(),
|
let b_scalar = b_scalar.into_float_value();
|
||||||
_ => unreachable!(),
|
let x = ctx.builder.build_float_mul(a_scalar, b_scalar, "").unwrap();
|
||||||
|
ctx.builder.build_float_add(old_result, x, "").unwrap().into()
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
panic!("Unrecognized dtype: {}", ctx.unifier.stringify(common_dtype));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
|
||||||
let acc_val = match acc_val {
|
|
||||||
BasicValueEnum::IntValue(e1) => ctx
|
|
||||||
.builder
|
|
||||||
.build_int_add(e1, product.into_int_value(), "")
|
|
||||||
.unwrap()
|
|
||||||
.as_basic_value_enum(),
|
|
||||||
BasicValueEnum::FloatValue(e1) => ctx
|
|
||||||
.builder
|
|
||||||
.build_float_add(e1, product.into_float_value(), "")
|
|
||||||
.unwrap()
|
|
||||||
.as_basic_value_enum(),
|
|
||||||
_ => unreachable!(),
|
|
||||||
};
|
|
||||||
ctx.builder.build_store(acc, acc_val).unwrap();
|
|
||||||
|
|
||||||
|
ctx.builder.build_store(result, new_result).unwrap();
|
||||||
Ok(())
|
Ok(())
|
||||||
},
|
},
|
||||||
llvm_usize.const_int(1, false),
|
|generator, ctx, (a_iter, b_iter)| {
|
||||||
)?;
|
a_iter.next(generator, ctx);
|
||||||
let acc_val = ctx.builder.build_load(acc, "").unwrap();
|
b_iter.next(generator, ctx);
|
||||||
Ok(acc_val)
|
Ok(())
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Ok(ctx.builder.build_load(result, "").unwrap())
|
||||||
}
|
}
|
||||||
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
(BasicValueEnum::IntValue(e1), BasicValueEnum::IntValue(e2)) => {
|
||||||
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
Ok(ctx.builder.build_int_mul(e1, e2, "").unwrap().as_basic_value_enum())
|
||||||
|
|
|
@ -0,0 +1,134 @@
|
||||||
|
use crate::{
|
||||||
|
codegen::{model::*, CodeGenContext, CodeGenerator},
|
||||||
|
typecheck::typedef::Type,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::NDArrayObject;
|
||||||
|
|
||||||
|
/// Fields of [`ContiguousNDArray`]
|
||||||
|
pub struct ContiguousNDArrayFields<'ctx, F: FieldTraversal<'ctx>, Item: Model<'ctx>> {
|
||||||
|
pub ndims: F::Out<Int<SizeT>>,
|
||||||
|
pub shape: F::Out<Ptr<Int<SizeT>>>,
|
||||||
|
pub data: F::Out<Ptr<Item>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An ndarray without strides and non-opaque `data` field in NAC3.
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub struct ContiguousNDArray<M> {
|
||||||
|
/// [`Model`] of the items.
|
||||||
|
pub item: M,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx, Item: Model<'ctx>> StructKind<'ctx> for ContiguousNDArray<Item> {
|
||||||
|
type Fields<F: FieldTraversal<'ctx>> = ContiguousNDArrayFields<'ctx, F, Item>;
|
||||||
|
|
||||||
|
fn traverse_fields<F: FieldTraversal<'ctx>>(&self, traversal: &mut F) -> Self::Fields<F> {
|
||||||
|
Self::Fields {
|
||||||
|
ndims: traversal.add_auto("ndims"),
|
||||||
|
shape: traversal.add_auto("shape"),
|
||||||
|
data: traversal.add("data", Ptr(self.item)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'ctx> NDArrayObject<'ctx> {
|
||||||
|
/// Create a [`ContiguousNDArray`] from the contents of this ndarray.
|
||||||
|
///
|
||||||
|
/// This function may or may not be expensive depending on if this ndarray has contiguous data.
|
||||||
|
///
|
||||||
|
/// If this ndarray is not C-contiguous, this function will allocate memory on the stack for the `data` field of
|
||||||
|
/// the returned [`ContiguousNDArray`] and copy contents of this ndarray to there.
|
||||||
|
///
|
||||||
|
/// If this ndarray is C-contiguous, contents of this ndarray will not be copied. The created [`ContiguousNDArray`]
|
||||||
|
/// will share memory with this ndarray.
|
||||||
|
///
|
||||||
|
/// The `item_model` sets the [`Model`] of the returned [`ContiguousNDArray`]'s `Item` model for type-safety, and
|
||||||
|
/// should match the `ctx.get_llvm_type()` of this ndarray's `dtype`. Otherwise this function panics. Use model [`Any`]
|
||||||
|
/// if you don't care/cannot know the [`Model`] in advance.
|
||||||
|
pub fn make_contiguous_ndarray<G: CodeGenerator + ?Sized, Item: Model<'ctx>>(
|
||||||
|
&self,
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
item_model: Item,
|
||||||
|
) -> Instance<'ctx, Ptr<Struct<ContiguousNDArray<Item>>>> {
|
||||||
|
// Sanity check on `self.dtype` and `item_model`.
|
||||||
|
let dtype_llvm = ctx.get_llvm_type(generator, self.dtype);
|
||||||
|
item_model.check_type(generator, ctx.ctx, dtype_llvm).unwrap();
|
||||||
|
|
||||||
|
let cdarray_model = Struct(ContiguousNDArray { item: item_model });
|
||||||
|
|
||||||
|
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||||
|
let then_bb = ctx.ctx.insert_basic_block_after(current_bb, "then_bb");
|
||||||
|
let else_bb = ctx.ctx.insert_basic_block_after(then_bb, "else_bb");
|
||||||
|
let end_bb = ctx.ctx.insert_basic_block_after(else_bb, "end_bb");
|
||||||
|
|
||||||
|
// Allocate and setup the resulting [`ContiguousNDArray`].
|
||||||
|
let result = cdarray_model.alloca(generator, ctx);
|
||||||
|
|
||||||
|
// Set ndims and shape.
|
||||||
|
let ndims = self.ndims_llvm(generator, ctx.ctx);
|
||||||
|
result.set(ctx, |f| f.ndims, ndims);
|
||||||
|
|
||||||
|
let shape = self.instance.get(generator, ctx, |f| f.shape);
|
||||||
|
result.set(ctx, |f| f.shape, shape);
|
||||||
|
|
||||||
|
let is_contiguous = self.is_c_contiguous(generator, ctx);
|
||||||
|
ctx.builder.build_conditional_branch(is_contiguous.value, then_bb, else_bb).unwrap();
|
||||||
|
|
||||||
|
// Inserting into then_bb; This ndarray is contiguous.
|
||||||
|
ctx.builder.position_at_end(then_bb);
|
||||||
|
let data = self.instance.get(generator, ctx, |f| f.data);
|
||||||
|
let data = data.pointer_cast(generator, ctx, item_model);
|
||||||
|
result.set(ctx, |f| f.data, data);
|
||||||
|
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||||
|
|
||||||
|
// Inserting into else_bb; This ndarray is not contiguous. Do a full-copy on `data`.
|
||||||
|
// `make_copy` produces an ndarray with contiguous `data`.
|
||||||
|
ctx.builder.position_at_end(else_bb);
|
||||||
|
let copied_ndarray = self.make_copy(generator, ctx);
|
||||||
|
let data = copied_ndarray.instance.get(generator, ctx, |f| f.data);
|
||||||
|
let data = data.pointer_cast(generator, ctx, item_model);
|
||||||
|
result.set(ctx, |f| f.data, data);
|
||||||
|
ctx.builder.build_unconditional_branch(end_bb).unwrap();
|
||||||
|
|
||||||
|
// Reposition to end_bb for continuation
|
||||||
|
ctx.builder.position_at_end(end_bb);
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an [`NDArrayObject`] from a [`ContiguousNDArray`].
|
||||||
|
///
|
||||||
|
/// The operation is super cheap. The newly created [`NDArrayObject`] will share the
|
||||||
|
/// same memory as the [`ContiguousNDArray`].
|
||||||
|
///
|
||||||
|
/// `ndims` has to be provided as [`NDArrayObject`] requires a statically known `ndims` value, despite
|
||||||
|
/// the fact that the information should be contained within the [`ContiguousNDArray`].
|
||||||
|
pub fn from_contiguous_ndarray<G: CodeGenerator + ?Sized, Item: Model<'ctx>>(
|
||||||
|
generator: &mut G,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
|
carray: Instance<'ctx, Ptr<Struct<ContiguousNDArray<Item>>>>,
|
||||||
|
dtype: Type,
|
||||||
|
ndims: u64,
|
||||||
|
) -> Self {
|
||||||
|
// Sanity check on `dtype` and `contiguous_array`'s `Item` model.
|
||||||
|
let dtype_llvm = ctx.get_llvm_type(generator, dtype);
|
||||||
|
carray.model.0 .0.item.check_type(generator, ctx.ctx, dtype_llvm).unwrap();
|
||||||
|
|
||||||
|
// TODO: Debug assert `ndims == carray.ndims` to catch bugs.
|
||||||
|
|
||||||
|
// Allocate the resulting ndarray.
|
||||||
|
let ndarray = NDArrayObject::alloca(generator, ctx, dtype, ndims);
|
||||||
|
|
||||||
|
// Copy shape and update strides
|
||||||
|
let shape = carray.get(generator, ctx, |f| f.shape);
|
||||||
|
ndarray.copy_shape_from_array(generator, ctx, shape);
|
||||||
|
ndarray.set_strides_contiguous(generator, ctx);
|
||||||
|
|
||||||
|
// Share data
|
||||||
|
let data = carray.get(generator, ctx, |f| f.data).pointer_cast(generator, ctx, Int(Byte));
|
||||||
|
ndarray.instance.set(ctx, |f| f.data, data);
|
||||||
|
|
||||||
|
ndarray
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod array;
|
pub mod array;
|
||||||
pub mod broadcast;
|
pub mod broadcast;
|
||||||
|
pub mod contiguous;
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod indexing;
|
pub mod indexing;
|
||||||
pub mod map;
|
pub mod map;
|
||||||
|
@ -26,7 +27,10 @@ use crate::{
|
||||||
model::*,
|
model::*,
|
||||||
CodeGenContext, CodeGenerator,
|
CodeGenContext, CodeGenerator,
|
||||||
},
|
},
|
||||||
toplevel::{helper::extract_ndims, numpy::unpack_ndarray_var_tys},
|
toplevel::{
|
||||||
|
helper::{create_ndims, extract_ndims},
|
||||||
|
numpy::{make_ndarray_ty, unpack_ndarray_var_tys},
|
||||||
|
},
|
||||||
typecheck::typedef::{Type, TypeEnum},
|
typecheck::typedef::{Type, TypeEnum},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -108,6 +112,18 @@ impl<'ctx> NDArrayObject<'ctx> {
|
||||||
Int(SizeT).const_int(generator, ctx, self.ndims)
|
Int(SizeT).const_int(generator, ctx, self.ndims)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the typechecker ndarray type of this [`NDArrayObject`].
|
||||||
|
pub fn get_type(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> Type {
|
||||||
|
let ndims = create_ndims(&mut ctx.unifier, self.ndims);
|
||||||
|
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(self.dtype), Some(ndims))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forget that this is an ndarray and convert into an [`AnyObject`].
|
||||||
|
pub fn to_any(&self, ctx: &mut CodeGenContext<'ctx, '_>) -> AnyObject<'ctx> {
|
||||||
|
let ty = self.get_type(ctx);
|
||||||
|
AnyObject { value: self.instance.value.as_basic_value_enum(), ty }
|
||||||
|
}
|
||||||
|
|
||||||
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
|
/// Allocate an ndarray on the stack given its `ndims` and `dtype`.
|
||||||
///
|
///
|
||||||
/// `shape` and `strides` will be automatically allocated on the stack.
|
/// `shape` and `strides` will be automatically allocated on the stack.
|
||||||
|
|
|
@ -2076,10 +2076,12 @@ impl<'a> BuiltinBuilder<'a> {
|
||||||
Box::new(move |ctx, _, fun, args, generator| {
|
Box::new(move |ctx, _, fun, args, generator| {
|
||||||
let x1_ty = fun.0.args[0].ty;
|
let x1_ty = fun.0.args[0].ty;
|
||||||
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
let x1_val = args[0].1.clone().to_basic_value_enum(ctx, generator, x1_ty)?;
|
||||||
|
|
||||||
let x2_ty = fun.0.args[1].ty;
|
let x2_ty = fun.0.args[1].ty;
|
||||||
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
let x2_val = args[1].1.clone().to_basic_value_enum(ctx, generator, x2_ty)?;
|
||||||
|
|
||||||
Ok(Some(ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?))
|
let result = ndarray_dot(generator, ctx, (x1_ty, x1_val), (x2_ty, x2_val))?;
|
||||||
|
Ok(Some(result))
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue