core: WIP - int32 works now
This commit is contained in:
parent
d75f970b4a
commit
5ab59147ca
|
@ -4,6 +4,10 @@ use inkwell::values::{BasicValueEnum, FloatValue, IntValue};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
|
||||||
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics};
|
use crate::codegen::{CodeGenContext, CodeGenerator, extern_fns, irrt, llvm_intrinsics};
|
||||||
|
use crate::codegen::classes::NDArrayValue;
|
||||||
|
use crate::codegen::numpy::ndarray_elementwise_unaryop_impl;
|
||||||
|
use crate::toplevel::helper::PRIMITIVE_DEF_IDS;
|
||||||
|
use crate::toplevel::numpy::unpack_ndarray_var_tys;
|
||||||
use crate::typecheck::typedef::Type;
|
use crate::typecheck::typedef::Type;
|
||||||
|
|
||||||
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
/// Shorthand for [`unreachable!()`] when a type of argument is not supported.
|
||||||
|
@ -21,20 +25,23 @@ fn unsupported_type(
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `int32` builtin function.
|
/// Invokes the `int32` builtin function.
|
||||||
pub fn call_int32<'ctx>(
|
pub fn call_int32<'ctx, G: CodeGenerator + ?Sized>(
|
||||||
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, '_>,
|
ctx: &mut CodeGenContext<'ctx, '_>,
|
||||||
n: (Type, BasicValueEnum<'ctx>),
|
n: (Type, BasicValueEnum<'ctx>),
|
||||||
) -> IntValue<'ctx> {
|
) -> Result<BasicValueEnum<'ctx>, String> {
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
let (n_ty, n) = n;
|
let (n_ty, n) = n;
|
||||||
|
|
||||||
match n.get_type() {
|
Ok(match n.get_type() {
|
||||||
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => {
|
BasicTypeEnum::IntType(int_ty) if matches!(int_ty.get_bit_width(), 1 | 8) => {
|
||||||
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
debug_assert!(ctx.unifier.unioned(n_ty, ctx.primitives.bool));
|
||||||
|
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_int_z_extend(n.into_int_value(), llvm_i32, "zext")
|
.build_int_z_extend(n.into_int_value(), llvm_i32, "zext")
|
||||||
|
.map(Into::into)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,7 +51,7 @@ pub fn call_int32<'ctx>(
|
||||||
ctx.primitives.uint32,
|
ctx.primitives.uint32,
|
||||||
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
].iter().any(|ty| ctx.unifier.unioned(n_ty, *ty)));
|
||||||
|
|
||||||
n.into_int_value()
|
n
|
||||||
}
|
}
|
||||||
|
|
||||||
BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 64 => {
|
BasicTypeEnum::IntType(int_ty) if int_ty.get_bit_width() == 64 => {
|
||||||
|
@ -55,6 +62,7 @@ pub fn call_int32<'ctx>(
|
||||||
|
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_int_truncate(n.into_int_value(), llvm_i32, "trunc")
|
.build_int_truncate(n.into_int_value(), llvm_i32, "trunc")
|
||||||
|
.map(Into::into)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,11 +74,29 @@ pub fn call_int32<'ctx>(
|
||||||
.unwrap();
|
.unwrap();
|
||||||
ctx.builder
|
ctx.builder
|
||||||
.build_int_truncate(to_int64, llvm_i32, "conv")
|
.build_int_truncate(to_int64, llvm_i32, "conv")
|
||||||
|
.map(Into::into)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
BasicTypeEnum::PointerType(_) if n_ty.obj_id(&ctx.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) => {
|
||||||
|
let (elem_ty, _) = unpack_ndarray_var_tys(&mut ctx.unifier, n_ty);
|
||||||
|
|
||||||
|
let ndarray = ndarray_elementwise_unaryop_impl(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
ctx.primitives.int32,
|
||||||
|
None,
|
||||||
|
NDArrayValue::from_ptr_val(n.into_pointer_value(), llvm_usize, None),
|
||||||
|
|generator, ctx, val| {
|
||||||
|
call_int32(generator, ctx, (elem_ty, val))
|
||||||
|
},
|
||||||
|
)?;
|
||||||
|
|
||||||
|
ndarray.as_ptr_value().into()
|
||||||
|
}
|
||||||
|
|
||||||
_ => unsupported_type(ctx, "int32", &[n_ty])
|
_ => unsupported_type(ctx, "int32", &[n_ty])
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Invokes the `int64` builtin function.
|
/// Invokes the `int64` builtin function.
|
||||||
|
|
|
@ -302,6 +302,12 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
||||||
Some("N".into()),
|
Some("N".into()),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
let ndarray_num_ty = make_ndarray_ty(unifier, primitives, Some(num_ty.0), None);
|
||||||
|
let num_or_ndarray_ty = unifier.get_fresh_var_with_range(
|
||||||
|
&[num_ty.0, ndarray_num_ty],
|
||||||
|
Some("T".into()),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
|
||||||
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
|
let var_map: VarMap = vec![(num_ty.1, num_ty.0)].into_iter().collect();
|
||||||
let exception_fields = vec![
|
let exception_fields = vec![
|
||||||
|
@ -568,8 +574,8 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
||||||
name: "int32".into(),
|
name: "int32".into(),
|
||||||
simple_name: "int32".into(),
|
simple_name: "int32".into(),
|
||||||
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
signature: unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }],
|
args: vec![FuncArg { name: "n".into(), ty: num_or_ndarray_ty.0, default_value: None }],
|
||||||
ret: int32,
|
ret: num_or_ndarray_ty.0,
|
||||||
vars: var_map.clone(),
|
vars: var_map.clone(),
|
||||||
})),
|
})),
|
||||||
var_id: Vec::default(),
|
var_id: Vec::default(),
|
||||||
|
@ -581,7 +587,7 @@ pub fn get_builtins(unifier: &mut Unifier, primitives: &PrimitiveStore) -> Built
|
||||||
let arg_ty = fun.0.args[0].ty;
|
let arg_ty = fun.0.args[0].ty;
|
||||||
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
let arg = args[0].1.clone().to_basic_value_enum(ctx, generator, arg_ty)?;
|
||||||
|
|
||||||
Ok(Some(builtin_fns::call_int32(ctx, (arg_ty, arg)).into()))
|
Ok(Some(builtin_fns::call_int32(generator, ctx, (arg_ty, arg))?))
|
||||||
},
|
},
|
||||||
)))),
|
)))),
|
||||||
loc: None,
|
loc: None,
|
||||||
|
|
|
@ -860,6 +860,51 @@ impl<'a> Inferencer<'a> {
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if id == &"int32".into() && args.len() == 1 {
|
||||||
|
let arg0 = self.fold_expr(args.remove(0))?;
|
||||||
|
let arg0_ty = arg0.custom.unwrap();
|
||||||
|
|
||||||
|
let ret = if arg0_ty.obj_id(self.unifier).is_some_and(|id| id == PRIMITIVE_DEF_IDS.ndarray) {
|
||||||
|
let (_, ndarray_ndims) = unpack_ndarray_var_tys(self.unifier, arg0_ty);
|
||||||
|
|
||||||
|
make_ndarray_ty(
|
||||||
|
self.unifier,
|
||||||
|
self.primitives,
|
||||||
|
Some(self.primitives.int32),
|
||||||
|
Some(ndarray_ndims),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
self.primitives.int32
|
||||||
|
};
|
||||||
|
|
||||||
|
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
|
||||||
|
args: vec![
|
||||||
|
FuncArg {
|
||||||
|
name: "n".into(),
|
||||||
|
ty: arg0.custom.unwrap(),
|
||||||
|
default_value: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
ret,
|
||||||
|
vars: VarMap::new(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
return Ok(Some(Located {
|
||||||
|
location,
|
||||||
|
custom: Some(ret),
|
||||||
|
node: ExprKind::Call {
|
||||||
|
func: Box::new(Located {
|
||||||
|
custom: Some(custom),
|
||||||
|
location: func.location,
|
||||||
|
node: ExprKind::Name { id: *id, ctx: ctx.clone() },
|
||||||
|
}),
|
||||||
|
args: vec![arg0],
|
||||||
|
keywords: vec![],
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
// int64 is special because its argument can be a constant larger than int32
|
// int64 is special because its argument can be a constant larger than int32
|
||||||
if id == &"int64".into() && args.len() == 1 {
|
if id == &"int64".into() && args.len() == 1 {
|
||||||
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
if let ExprKind::Constant { value: ast::Constant::Int(val), kind } =
|
||||||
|
|
|
@ -649,6 +649,13 @@ def test_ndarray_ge_broadcast_rhs_scalar():
|
||||||
output_ndarray_float_2(x)
|
output_ndarray_float_2(x)
|
||||||
output_ndarray_bool_2(y)
|
output_ndarray_bool_2(y)
|
||||||
|
|
||||||
|
def test_ndarray_int32():
|
||||||
|
x = np_identity(2)
|
||||||
|
y = int32(x)
|
||||||
|
|
||||||
|
output_ndarray_float_2(x)
|
||||||
|
output_ndarray_int32_2(y)
|
||||||
|
|
||||||
def run() -> int32:
|
def run() -> int32:
|
||||||
test_ndarray_ctor()
|
test_ndarray_ctor()
|
||||||
test_ndarray_empty()
|
test_ndarray_empty()
|
||||||
|
@ -739,4 +746,6 @@ def run() -> int32:
|
||||||
test_ndarray_ge_broadcast_lhs_scalar()
|
test_ndarray_ge_broadcast_lhs_scalar()
|
||||||
test_ndarray_ge_broadcast_rhs_scalar()
|
test_ndarray_ge_broadcast_rhs_scalar()
|
||||||
|
|
||||||
|
test_ndarray_int32()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue