Compare commits

...

9 Commits

Author SHA1 Message Date
abdul124 43991b150b nac3artiq: allow class attribute access without init function 2024-06-21 10:02:41 +08:00
abdul124 3b6a6f560f core: add support for class attributes 2024-06-21 09:42:02 +08:00
abdul124 e91d24fd10 core: add attribute field to class definition 2024-06-21 09:42:02 +08:00
lyken d89146aa02 core: use no_run on builtin_fns docs 2024-06-20 13:53:25 +08:00
David Mak 5bade81ddb standalone: Add test for multidim array index with one index 2024-06-20 12:50:30 +08:00
David Mak 0452e6de78 core: Fix codegen for tuple-index into ndarray 2024-06-20 12:50:30 +08:00
David Mak 635c944c90 core: Fix type inference for tuple-index into ndarray
Fixes #420.
2024-06-20 12:50:30 +08:00
lyken e36af3b0a3 core: reduce code duplication in codegen/builtin_fns (#422)
Used macros to generate some unary math functions.

Reviewed-on: #422
Reviewed-by: David Mak <chmakac@connect.ust.hk>
Co-authored-by: lyken <lyken@m-labs.hk>
Co-committed-by: lyken <lyken@m-labs.hk>
2024-06-20 12:48:44 +08:00
Sebastien Bourdeauducq 5b1aa812ed update dependencies 2024-06-20 10:43:55 +08:00
17 changed files with 980 additions and 1485 deletions

8
Cargo.lock generated
View File

@ -574,9 +574,9 @@ checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c"
[[package]]
name = "memchr"
version = "2.7.2"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "memoffset"
@ -927,9 +927,9 @@ dependencies = [
[[package]]
name = "redox_syscall"
version = "0.5.1"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e"
checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd"
dependencies = [
"bitflags",
]

View File

@ -2,11 +2,11 @@
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1717196966,
"narHash": "sha256-yZKhxVIKd2lsbOqYd5iDoUIwsRZFqE87smE2Vzf6Ck0=",
"lastModified": 1718530797,
"narHash": "sha256-pup6cYwtgvzDpvpSCFh1TEUjw2zkNpk8iolbKnyFmmU=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "57610d2f8f0937f39dbd72251e9614b1561942d8",
"rev": "b60ebf54c15553b393d144357375ea956f89e9a9",
"type": "github"
},
"original": {

View File

@ -0,0 +1,40 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
attr2: int32 = 4
attr3: Kernel[int32]
@kernel
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
attr4: KernelInvariant[int32] = 16
def __init__(self):
self.core = Core()
@kernel
def run(self):
Demo.attr1 # Supported
# Demo.attr2 # Field not accessible on Kernel
# Demo.attr3 # Only attributes can be accessed in this way
# Demo.attr1 = 2 # Attributes are immutable
self.ATTR4 # Attributes can be accessed within class
obj = Demo()
obj.attr1 # Attributes can be accessed by class objects
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

View File

@ -657,7 +657,7 @@ pub fn attributes_writeback(
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string());
let index = ctx.get_attr_index(ty, *name);
let (index, _) = ctx.get_attr_index(ty, *name);
values.push((
*field_ty,
ctx.build_gep_and_load(

View File

@ -627,12 +627,15 @@ impl InnerResolver {
let pyid_to_def = self.pyid_to_def.read();
let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| {
defs.iter().find_map(|def| {
if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() {
if object_id == def_id
&& constructor.is_some()
&& methods.iter().any(|(s, _, _)| s == &"__init__".into())
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*rear_guard
{
return *constructor;
if object_id == def_id
&& constructor.is_some()
&& methods.iter().any(|(s, _, _)| s == &"__init__".into())
{
return *constructor;
}
}
}
None
@ -664,7 +667,29 @@ impl InnerResolver {
primitives,
)? {
Ok(s) => s,
Err(e) => return Ok(Err(e)),
Err(e) => {
// Allow access to Class Attributes of Classes without having to initialize Objects
if self.pyid_to_def.read().contains_key(&py_obj_id) {
if let Some(def_id) = self.pyid_to_def.read().get(&py_obj_id).copied() {
let def = defs[def_id.0].read();
let TopLevelDef::Class { object_id, .. } = &*def else {
// only object is supported, functions are not supported
unreachable!("function type is not supported, should not be queried")
};
let ty = TypeEnum::TObj {
obj_id: *object_id,
params: VarMap::new(),
fields: HashMap::new(),
};
(unifier.add_ty(ty), true)
} else {
return Ok(Err(e));
}
} else {
return Ok(Err(e));
}
}
};
match (&*unifier.get_ty(extracted_ty), inst_check) {
// do the instantiation for these four types

File diff suppressed because it is too large Load Diff

View File

@ -3,8 +3,8 @@ use std::{collections::HashMap, convert::TryInto, iter::once, iter::zip};
use crate::{
codegen::{
classes::{
ArrayLikeIndexer, ArrayLikeValue, ArraySliceValue, ListValue, NDArrayValue, ProxyValue,
RangeValue, TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
ArrayLikeIndexer, ArrayLikeValue, ListValue, NDArrayValue, ProxyValue, RangeValue,
TypedArrayLikeAccessor, UntypedArrayLikeAccessor,
},
concrete_type::{ConcreteFuncArg, ConcreteTypeEnum, ConcreteTypeStore},
gen_in_range_check, get_llvm_abi_type, get_llvm_type,
@ -86,19 +86,35 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
get_subst_key(&mut self.unifier, obj, &fun.vars, filter)
}
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize {
/// Checks the field and attributes of classes
/// Returns the index of attr in class fields otherwise returns the attribute value
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) {
let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id,
// we cannot have other types, virtual type should be handled by function calls
_ => unreachable!(),
};
let def = &self.top_level.definitions.read()[obj_id.0];
let index = if let TopLevelDef::Class { fields, .. } = &*def.read() {
fields.iter().find_position(|x| x.0 == attr).unwrap().0
let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() {
if let Some(field_index) = fields.iter().find_position(|x| x.0 == attr) {
(field_index.0, None)
} else {
let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap();
(attribute_index.0, Some(attribute_index.1 .2.clone()))
}
} else {
unreachable!()
};
index
(index, value)
}
pub fn get_attr_index_object(&mut self, ty: Type, attr: StrRef) -> usize {
match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { fields, .. } => {
fields.iter().find_position(|x| *x.0 == attr).unwrap().0
}
_ => unreachable!(),
}
}
pub fn gen_symbol_val<G: CodeGenerator + ?Sized>(
@ -1741,22 +1757,37 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
let ndims = values
.iter()
.map(|ndim| match *ndim {
SymbolValue::U64(v) => Ok(v),
SymbolValue::U32(v) => Ok(u64::from(v)),
SymbolValue::I32(v) => u64::try_from(v)
.map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")),
SymbolValue::I64(v) => u64::try_from(v)
.map_err(|_| format!("Expected non-negative literal for ndarray.ndims, got {v}")),
_ => unreachable!(),
})
.collect::<Result<Vec<_>, _>>()?;
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
.collect::<Result<Vec<_>, _>>()
.map_err(|val| {
format!(
"Expected non-negative literal for ndarray.ndims, got {}",
i128::try_from(val).unwrap()
)
})?;
assert!(!ndims.is_empty());
let ndarray_ndims_ty = ctx
.unifier
.get_fresh_literal(ndims.iter().map(|v| SymbolValue::U64(v - 1)).collect(), None);
// The number of dimensions subscripted by the index expression.
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
// dimension will remove a dimension.
let subscripted_dims = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
if let ExprKind::Slice { .. } = &value_subexpr.node {
acc
} else {
acc + 1
}
}),
ExprKind::Slice { .. } => 0,
_ => 1,
};
let ndarray_ndims_ty = ctx.unifier.get_fresh_literal(
ndims.iter().map(|v| SymbolValue::U64(v - subscripted_dims)).collect(),
None,
);
let ndarray_ty =
make_ndarray_ty(&mut ctx.unifier, &ctx.primitives, Some(ty), Some(ndarray_ndims_ty));
let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type();
@ -1859,123 +1890,165 @@ fn gen_ndarray_subscript_expr<'ctx, G: CodeGenerator>(
}
};
Ok(Some(match &slice.node {
ExprKind::Tuple { elts, .. } => {
let slices = elts
.iter()
.enumerate()
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
.collect::<Result<Vec<_>, _>>()?;
if slices.len() < elts.len() {
return Ok(None);
}
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
}
ExprKind::Slice { .. } => {
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
return Ok(None);
};
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
}
_ => {
let index = if let Some(index) = generator.gen_expr(ctx, slice)? {
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value()
} else {
return Ok(None);
};
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
let index_addr = generator.gen_var_alloc(ctx, index.get_type().into(), None)?;
ctx.builder.build_store(index_addr, index).unwrap();
if ndims.len() == 1 && ndims[0] == 1 {
// Accessing an element from a 1-dimensional `ndarray`
return Ok(Some(
v.data()
.get(
ctx,
generator,
&ArraySliceValue::from_ptr_val(
index_addr,
llvm_usize.const_int(1, false),
None,
),
None,
)
.into(),
));
}
// Accessing an element from a multi-dimensional `ndarray`
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray = generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
let make_indices_arr = |generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>|
-> Result<_, String> {
Ok(if let ExprKind::Tuple { elts, .. } = &slice.node {
let llvm_int_ty = ctx.get_llvm_type(generator, elts[0].custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
generator,
ctx.builder.build_int_sub(num_dims, llvm_usize.const_int(1, false), "").unwrap(),
);
llvm_int_ty,
llvm_usize.const_int(elts.len() as u64, false),
None,
)?;
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
for (i, elt) in elts.iter().enumerate() {
let Some(index) = generator.gen_expr(ctx, elt)? else {
return Ok(None);
};
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
let index = index
.to_basic_value_enum(ctx, generator, elt.custom.unwrap())?
.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else {
return Ok(None);
};
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(i as u64, false),
None,
)
};
ctx.builder.build_store(store_ptr, index).unwrap();
}
Some(index_addr)
} else if let Some(index) = generator.gen_expr(ctx, slice)? {
let llvm_int_ty = ctx.get_llvm_type(generator, slice.custom.unwrap());
let index_addr = generator.gen_array_var_alloc(
ctx,
llvm_int_ty,
llvm_usize.const_int(1u64, false),
None,
)?;
let index =
index.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value();
let Some(index) = normalize_index(generator, ctx, index, 0)? else { return Ok(None) };
let store_ptr = unsafe {
index_addr.ptr_offset_unchecked(ctx, generator, &llvm_usize.const_zero(), None)
};
ctx.builder.build_store(store_ptr, index).unwrap();
Some(index_addr)
} else {
None
})
};
Ok(Some(if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
v.data().get(ctx, generator, &index_addr, None).into()
} else {
match &slice.node {
ExprKind::Tuple { elts, .. } => {
let slices = elts
.iter()
.enumerate()
.map(|(dim, elt)| expr_to_slice(generator, ctx, &elt.node, dim as u64))
.take_while_inclusive(|slice| slice.as_ref().is_ok_and(Option::is_some))
.collect::<Result<Vec<_>, _>>()?;
if slices.len() < elts.len() {
return Ok(None);
}
let slices = slices.into_iter().map(Option::unwrap).collect_vec();
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &slices)?.as_base_value().into()
}
ExprKind::Slice { .. } => {
let Some(slice) = expr_to_slice(generator, ctx, &slice.node, 0)? else {
return Ok(None);
};
numpy::ndarray_sliced_copy(generator, ctx, ty, v, &[slice])?.as_base_value().into()
}
_ => {
// Accessing an element from a multi-dimensional `ndarray`
let Some(index_addr) = make_indices_arr(generator, ctx)? else { return Ok(None) };
// Create a new array, remove the top dimension from the dimension-size-list, and copy the
// elements over
let subscripted_ndarray =
generator.gen_var_alloc(ctx, llvm_ndarray_t.into(), None)?;
let ndarray = NDArrayValue::from_ptr_val(subscripted_ndarray, llvm_usize, None);
let num_dims = v.load_ndims(ctx);
ndarray.store_ndims(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
ctx.builder
.build_int_sub(num_dims, llvm_usize.const_int(1, false), "")
.unwrap(),
);
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
let v_data_src_ptr = v.data().ptr_offset(
ctx,
generator,
&ArraySliceValue::from_ptr_val(index_addr, llvm_usize.const_int(1, false), None),
None,
);
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_elems, llvm_ndarray_data_t.size_of().unwrap(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
let ndarray_num_dims = ndarray.load_ndims(ctx);
let v_dims_src_ptr = unsafe {
v.dim_sizes().ptr_offset_unchecked(
ctx,
generator,
&llvm_usize.const_int(1, false),
None,
)
};
call_memcpy_generic(
ctx,
ndarray.dim_sizes().base_ptr(ctx, generator),
v_dims_src_ptr,
ctx.builder
.build_int_mul(ndarray_num_dims, llvm_usize.size_of(), "")
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
ndarray.as_base_value().into()
let ndarray_num_elems = call_ndarray_calc_size(
generator,
ctx,
&ndarray.dim_sizes().as_slice_value(ctx, generator),
(None, None),
);
ndarray.create_data(ctx, llvm_ndarray_data_t, ndarray_num_elems);
let v_data_src_ptr = v.data().ptr_offset(ctx, generator, &index_addr, None);
call_memcpy_generic(
ctx,
ndarray.data().base_ptr(ctx, generator),
v_data_src_ptr,
ctx.builder
.build_int_mul(
ndarray_num_elems,
llvm_ndarray_data_t.size_of().unwrap(),
"",
)
.map(Into::into)
.unwrap(),
llvm_i1.const_zero(),
);
ndarray.as_base_value().into()
}
}
}))
}
@ -2109,11 +2182,72 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}
ExprKind::Attribute { value, attr, .. } => {
// note that we would handle class methods directly in calls
// Change Class attribute access requests to accessing constants from Class Definition
if let Some(c) = value.custom {
if let TypeEnum::TFunc(_) = &*ctx.unifier.get_ty(c) {
let defs = ctx.top_level.definitions.read();
let result = defs.iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class {
constructor: Some(constructor),
attributes,
..
} = &*rear_guard
{
if *constructor == c {
return attributes.iter().find_map(|f| {
if f.0 == *attr {
// All other checks performed by this point
return Some(f.2.clone());
}
None
});
}
}
}
None
});
match result {
Some(val) => {
let mut modified_expr = expr.clone();
modified_expr.node = ExprKind::Constant { value: val, kind: None };
return generator.gen_expr(ctx, &modified_expr);
}
None => unreachable!("Function Type should not have attributes"),
}
} else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) {
if fields.is_empty() && params.is_empty() {
let defs = ctx.top_level.definitions.read();
let def = defs[obj_id.0].read();
match if let TopLevelDef::Class { attributes, .. } = &*def {
attributes.iter().find_map(|f| {
if f.0 == *attr {
return Some(f.2.clone());
}
None
})
} else {
None
} {
Some(val) => {
let mut modified_expr = expr.clone();
modified_expr.node = ExprKind::Constant { value: val, kind: None };
return generator.gen_expr(ctx, &modified_expr);
}
None => unreachable!(),
}
}
}
}
match generator.gen_expr(ctx, value)? {
Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else(
|| {
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr);
Ok(ValueEnum::Dynamic(ctx.build_gep_and_load(
v.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
@ -2123,7 +2257,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
Ok,
)?,
Some(ValueEnum::Dynamic(v)) => {
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
let (index, attr_value) = ctx.get_attr_index(value.custom.unwrap(), *attr);
if let Some(val) = attr_value {
// Change to Constant Construct
let mut modified_expr = expr.clone();
modified_expr.node = ExprKind::Constant { value: val, kind: None };
return generator.gen_expr(ctx, &modified_expr);
}
ValueEnum::Dynamic(ctx.build_gep_and_load(
v.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
@ -2306,6 +2447,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ExprKind::Attribute { value, attr, .. } => {
let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) };
// Handle Class Method calls
let id = if let TypeEnum::TObj { obj_id, .. } =
&*ctx.unifier.get_ty(value.custom.unwrap())
{

View File

@ -113,7 +113,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
}
},
ExprKind::Attribute { value, attr, .. } => {
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr);
let val = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
} else {

View File

@ -83,6 +83,7 @@ pub fn get_exn_constructor(
object_id: DefinitionId(class_id),
type_vars: Vec::default(),
fields: exception_fields,
attributes: Vec::default(),
methods: vec![("__init__".into(), signature, DefinitionId(cons_id))],
ancestors: vec![
TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Vec::default() },
@ -596,6 +597,7 @@ impl<'a> BuiltinBuilder<'a> {
object_id: prim.id(),
type_vars: Vec::default(),
fields: make_exception_fields(int32, int64, str),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: vec![],
constructor: None,
@ -624,7 +626,8 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
object_id: prim.id(),
type_vars: vec![self.option_tvar.ty],
fields: vec![],
fields: Vec::default(),
attributes: Vec::default(),
methods: vec![
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0),
@ -738,6 +741,7 @@ impl<'a> BuiltinBuilder<'a> {
object_id: prim.id(),
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
fields: Vec::default(),
attributes: Vec::default(),
methods: vec![
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0),

View File

@ -1057,7 +1057,14 @@ impl TopLevelComposer {
let (keyword_list, core_config) = core_info;
let mut class_def = class_def.write();
let TopLevelDef::Class {
object_id, ancestors, fields, methods, resolver, type_vars, ..
object_id,
ancestors,
fields,
attributes,
methods,
resolver,
type_vars,
..
} = &mut *class_def
else {
unreachable!("here must be toplevel class def");
@ -1073,10 +1080,14 @@ impl TopLevelComposer {
class_body_ast,
_class_ancestor_def,
class_fields_def,
class_attributes_def,
class_methods_def,
class_type_vars_def,
class_resolver,
) = (*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver);
) = (
*object_id, *name, bases, body, ancestors, fields, attributes, methods, type_vars,
resolver,
);
let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.as_ref();
@ -1285,34 +1296,74 @@ impl TopLevelComposer {
.unify(method_dummy_ty, method_type)
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
}
ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => {
ast::StmtKind::AnnAssign { target, annotation, value, .. } => {
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
if defined_fields.insert(attr.to_string()) {
let dummy_field_type = unifier.get_dummy_var().ty;
// handle Kernel[T], KernelInvariant[T]
let (annotation, mutable) = match &annotation.node {
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into()
) =>
{
(slice, false)
let annotation = match value {
None => {
// handle Kernel[T], KernelInvariant[T]
let (annotation, mutable) = match &annotation.node {
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into()
) =>
{
(slice, false)
}
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into())
) =>
{
(slice, true)
}
_ if core_config.kernel_ann.is_none() => (annotation, true),
_ => continue, // ignore fields annotated otherwise
};
class_fields_def.push((*attr, dummy_field_type, mutable));
annotation
}
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into())
) =>
{
(slice, true)
}
_ if core_config.kernel_ann.is_none() => (annotation, true),
_ => continue, // ignore fields annotated otherwise
};
class_fields_def.push((*attr, dummy_field_type, mutable));
// Supporting Class Attributes
Some(boxed_expr) => {
// Class attributes are set as immutable regardless
let (annotation, _) = match &annotation.node {
ast::ExprKind::Subscript { slice, .. } => (slice, false),
_ if core_config.kernel_ann.is_none() => (annotation, false),
_ => continue,
};
match &**boxed_expr {
ast::Located {location: _, custom: (), node: ast::ExprKind::Constant { value: v, kind: _ }} => {
// Restricting the types allowed to be defined as class attributes
match v {
ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {}
_ => {
return Err(HashSet::from([
format!(
"unsupported statement in class definition body (at {})",
b.location
),
]))
}
}
class_attributes_def.push((*attr, dummy_field_type, v.clone()));
}
_ => {
return Err(HashSet::from([
format!(
"unsupported statement in class definition body (at {})",
b.location
),
]))
}
}
annotation
}
};
let parsed_annotation = parse_ast_to_type_annotation_kinds(
class_resolver,
temp_def_list,
@ -1384,7 +1435,14 @@ impl TopLevelComposer {
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
) -> Result<(), HashSet<String>> {
let TopLevelDef::Class {
object_id, ancestors, fields, methods, resolver, type_vars, ..
object_id,
ancestors,
fields,
attributes,
methods,
resolver,
type_vars,
..
} = class_def
else {
unreachable!("here must be class def ast")
@ -1393,10 +1451,11 @@ impl TopLevelComposer {
_class_id,
class_ancestor_def,
class_fields_def,
class_attribute_def,
class_methods_def,
_class_type_vars_def,
_class_resolver,
) = (*object_id, ancestors, fields, methods, type_vars, resolver);
) = (*object_id, ancestors, fields, attributes, methods, type_vars, resolver);
// since when this function is called, the ancestors of the direct parent
// are supposed to be already handled, so we only need to deal with the direct parent
@ -1407,7 +1466,7 @@ impl TopLevelComposer {
let base = temp_def_list.get(id.0).unwrap();
let base = base.read();
let TopLevelDef::Class { methods, fields, .. } = &*base else {
let TopLevelDef::Class { methods, fields, attributes, .. } = &*base else {
unreachable!("must be top level class def")
};
@ -1449,7 +1508,7 @@ impl TopLevelComposer {
}
}
// use the new_child_methods to replace all the elements in `class_methods_def`
class_methods_def.drain(..);
class_methods_def.clear();
class_methods_def.extend(new_child_methods);
// handle class fields
@ -1459,7 +1518,9 @@ impl TopLevelComposer {
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
// find if there is a fields with the same name in the child class
for (class_field_name, ..) in &*class_fields_def {
if class_field_name == anc_field_name {
if class_field_name == anc_field_name
|| attributes.iter().any(|f| f.0 == *class_field_name)
{
return Err(HashSet::from([format!(
"field `{class_field_name}` has already declared in the ancestor classes"
)]));
@ -1467,14 +1528,33 @@ impl TopLevelComposer {
}
new_child_fields.push(to_be_added);
}
// handle class attributes
let mut new_child_attributes: Vec<(StrRef, Type, ast::Constant)> = Vec::new();
for (anc_attr_name, anc_attr_ty, attr_value) in attributes {
let to_be_added = (*anc_attr_name, *anc_attr_ty, attr_value.clone());
// find if there is a attribute with the same name in the child class
for (class_attr_name, ..) in &*class_attribute_def {
if class_attr_name == anc_attr_name
|| fields.iter().any(|f| f.0 == *class_attr_name)
{
return Err(HashSet::from([format!(
"attribute `{class_attr_name}` has already declared in the ancestor classes"
)]));
}
}
new_child_attributes.push(to_be_added);
}
for (class_field_name, class_field_ty, mutable) in &*class_fields_def {
if !is_override.contains(class_field_name) {
new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
}
}
class_fields_def.drain(..);
class_fields_def.clear();
class_fields_def.extend(new_child_fields);
class_attribute_def.clear();
class_attribute_def.extend(new_child_attributes);
Ok(())
}

View File

@ -474,6 +474,7 @@ impl TopLevelComposer {
object_id: obj_id,
type_vars: Vec::default(),
fields: Vec::default(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
constructor,

View File

@ -103,6 +103,10 @@ pub enum TopLevelDef {
///
/// Name and type is mutable.
fields: Vec<(StrRef, Type, bool)>,
/// Class Attributes.
///
/// Name, type, value.
attributes: Vec<(StrRef, Type, ast::Constant)>,
/// Class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>,
/// Ancestor classes, including itself.

View File

@ -470,6 +470,7 @@ pub fn get_type_from_type_annotation_kinds(
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {

View File

@ -34,6 +34,7 @@ pub enum TypeErrorKind {
},
RequiresTypeAnn,
PolymorphicFunctionPointer,
NoSuchAttribute(RecordKey, Type),
}
#[derive(Debug, Clone)]
@ -156,6 +157,10 @@ impl<'a> Display for DisplayTypeError<'a> {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` field/method does not exist")
}
NoSuchAttribute(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` is not a class attribute")
}
TupleIndexOutOfBounds { index, len } => {
write!(
f,

View File

@ -6,6 +6,7 @@ use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::toplevel::TopLevelDef;
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
@ -1441,6 +1442,24 @@ impl<'a> Inferencer<'a> {
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
}
/// Checks for non-class attributes
fn infer_general_attribute(
&mut self,
value: &ast::Expr<Option<Type>>,
attr: StrRef,
ctx: ExprContext,
) -> InferenceResult {
let attr_ty = self.unifier.get_dummy_var().ty;
let fields = once((
attr.into(),
RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)),
))
.collect();
let record = self.unifier.add_record(fields);
self.constrain(value.custom.unwrap(), record, &value.location)?;
Ok(attr_ty)
}
fn infer_attribute(
&mut self,
value: &ast::Expr<Option<Type>>,
@ -1448,31 +1467,72 @@ impl<'a> Inferencer<'a> {
ctx: ExprContext,
) -> InferenceResult {
let ty = value.custom.unwrap();
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) {
// just a fast path
match (fields.get(&attr), ctx == ExprContext::Store) {
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
(Some((_, false)), true) => {
report_error(&format!("Field `{attr}` is immutable"), value.location)
}
(None, _) => {
let t = self.unifier.stringify(ty);
report_error(
&format!("`{t}::{attr}` field/method does not exist"),
value.location,
)
(None, mutable) => {
// Check whether it is a class attribute
let defs = self.top_level.definitions.read();
let result = {
if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() {
attributes.iter().find_map(|f| {
if f.0 == attr {
return Some(f.1);
}
None
})
} else {
None
}
};
match result {
Some(res) if !mutable => Ok(res),
Some(_) => report_error(
&format!("Class Attribute `{attr}` is immutable"),
value.location,
),
None => {
let t = self.unifier.stringify(ty);
report_error(
&format!("`{t}::{attr}` field/method does not exist"),
value.location,
)
}
}
}
}
} else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
// Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1
let result = {
self.top_level.definitions.read().iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard {
if name.to_string() == self.unifier.stringify(sign.ret) {
return attributes.iter().find_map(|f| {
if f.0 == attr {
return Some(f.clone().1);
}
None
});
}
}
}
None
})
};
match result {
Some(f) if ctx != ExprContext::Store => Ok(f),
Some(_) => {
report_error(&format!("Class Attribute `{attr}` is immutable"), value.location)
}
None => self.infer_general_attribute(value, attr, ctx),
}
} else {
let attr_ty = self.unifier.get_dummy_var().ty;
let fields = once((
attr.into(),
RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)),
))
.collect();
let record = self.unifier.add_record(fields);
self.constrain(value.custom.unwrap(), record, &value.location)?;
Ok(attr_ty)
self.infer_general_attribute(value, attr, ctx)
}
}
@ -1586,6 +1646,7 @@ impl<'a> Inferencer<'a> {
fn infer_subscript_ndarray(
&mut self,
value: &ast::Expr<Option<Type>>,
slice: &ast::Expr<Option<Type>>,
dummy_tvar: Type,
ndims: Type,
) -> InferenceResult {
@ -1604,48 +1665,66 @@ impl<'a> Inferencer<'a> {
let ndims = values
.iter()
.map(|ndim| match *ndim {
SymbolValue::U64(v) => Ok(v),
SymbolValue::U32(v) => Ok(u64::from(v)),
SymbolValue::I32(v) => u64::try_from(v).map_err(|_| {
HashSet::from([format!(
"Expected non-negative literal for ndarray.ndims, got {v}"
)])
}),
SymbolValue::I64(v) => u64::try_from(v).map_err(|_| {
HashSet::from([format!(
"Expected non-negative literal for ndarray.ndims, got {v}"
)])
}),
_ => unreachable!(),
})
.collect::<Result<Vec<_>, _>>()?;
.map(|ndim| u64::try_from(ndim.clone()).map_err(|()| ndim.clone()))
.collect::<Result<Vec<_>, _>>()
.map_err(|val| {
HashSet::from([format!(
"Expected non-negative literal for ndarray.ndims, got {}",
i128::try_from(val).unwrap()
)])
})?;
assert!(!ndims.is_empty());
if ndims.len() == 1 && ndims[0] == 1 {
// ndarray[T, Literal[1]] - Index always returns an object of type T
// The number of dimensions subscripted by the index expression.
// Slicing a ndarray will yield the same number of dimensions, whereas indexing into a
// dimension will remove a dimension.
let subscripted_dims = match &slice.node {
ExprKind::Tuple { elts, .. } => elts.iter().fold(0, |acc, value_subexpr| {
if let ExprKind::Slice { .. } = &value_subexpr.node {
acc
} else {
acc + 1
}
}),
ExprKind::Slice { .. } => 0,
_ => 1,
};
if ndims.len() == 1 && ndims[0] - subscripted_dims == 0 {
// ndarray[T, Literal[1]] - Non-Slice index always returns an object of type T
assert_ne!(ndims[0], 0);
Ok(dummy_tvar)
} else {
// ndarray[T, Literal[N]] where N != 1 - Index returns an object of type ndarray[T, Literal[N - 1]]
// Otherwise - Index returns an object of type ndarray[T, Literal[N - subscripted_dims]]
if ndims.iter().any(|v| *v == 0) {
// Disallow subscripting if any Literal value will subscript on an element
let new_ndims = ndims
.into_iter()
.map(|v| {
let v = i128::from(v) - i128::from(subscripted_dims);
u64::try_from(v)
})
.collect::<Result<Vec<_>, _>>()
.map_err(|_| {
HashSet::from([format!(
"Cannot subscript {} by {subscripted_dims} dimensions",
self.unifier.stringify(value.custom.unwrap()),
)])
})?;
if new_ndims.iter().any(|v| *v == 0) {
unimplemented!("Inference for ndarray subscript operator with Literal[0, ...] bound unimplemented")
}
let ndims_min_one_ty = self.unifier.get_fresh_literal(
ndims.into_iter().map(|v| SymbolValue::U64(v - 1)).collect(),
None,
);
let subscripted_ty = make_ndarray_ty(
self.unifier,
self.primitives,
Some(dummy_tvar),
Some(ndims_min_one_ty),
);
let ndims_ty = self
.unifier
.get_fresh_literal(new_ndims.into_iter().map(SymbolValue::U64).collect(), None);
let subscripted_ty =
make_ndarray_ty(self.unifier, self.primitives, Some(dummy_tvar), Some(ndims_ty));
Ok(subscripted_ty)
}
@ -1682,7 +1761,7 @@ impl<'a> Inferencer<'a> {
TypeEnum::TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
let (_, ndims) =
unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
self.infer_subscript_ndarray(value, ty, ndims)
self.infer_subscript_ndarray(value, slice, ty, ndims)
}
_ => {
// the index is a constant, so value can be a sequence.
@ -1725,10 +1804,7 @@ impl<'a> Inferencer<'a> {
}
let (_, ndims) = unpack_ndarray_var_tys(self.unifier, value.custom.unwrap());
let ndarray_ty =
make_ndarray_ty(self.unifier, self.primitives, Some(ty), Some(ndims));
self.constrain(value.custom.unwrap(), ndarray_ty, &value.location)?;
Ok(ndarray_ty)
self.infer_subscript_ndarray(value, slice, ty, ndims)
}
_ => {
if let TypeEnum::TTuple { .. } = &*self.unifier.get_ty(value.custom.unwrap()) {
@ -1763,7 +1839,7 @@ impl<'a> Inferencer<'a> {
.get_fresh_var_with_range(valid_index_tys.as_slice(), None, None)
.ty;
self.constrain(slice.custom.unwrap(), valid_index_ty, &slice.location)?;
self.infer_subscript_ndarray(value, ty, ndims)
self.infer_subscript_ndarray(value, slice, ty, ndims)
}
_ => unreachable!(),
}

View File

@ -289,6 +289,7 @@ impl TestEnvironment {
object_id: DefinitionId(i),
type_vars: Vec::default(),
fields: Vec::default(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -331,6 +332,7 @@ impl TestEnvironment {
object_id: DefinitionId(defs + 1),
type_vars: vec![tvar.ty],
fields: [("a".into(), tvar.ty, true)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -365,6 +367,7 @@ impl TestEnvironment {
object_id: DefinitionId(defs + 2),
type_vars: Vec::default(),
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -393,6 +396,7 @@ impl TestEnvironment {
object_id: DefinitionId(defs + 3),
type_vars: Vec::default(),
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,

View File

@ -150,6 +150,15 @@ def test_ndarray_slices():
x2 = x[0::2, 0::2]
output_ndarray_float_2(x2)
def test_ndarray_nd_idx():
x = np_identity(2)
x0: float = x[0, 0]
output_float64(x0)
output_float64(x[0, 1])
output_float64(x[1, 0])
output_float64(x[1, 1])
def test_ndarray_add():
x = np_identity(2)
y = x + np_ones([2, 2])
@ -1393,6 +1402,7 @@ def run() -> int32:
test_ndarray_neg_idx()
test_ndarray_slices()
test_ndarray_nd_idx()
test_ndarray_add()
test_ndarray_add_broadcast()