Compare commits

...

7 Commits

20 changed files with 674 additions and 144 deletions

View File

@ -0,0 +1,40 @@
from min_artiq import *
from numpy import int32
@nac3
class Demo:
attr1: KernelInvariant[int32] = 2
attr2: int32 = 4
attr3: Kernel[int32]
@kernel
def __init__(self):
self.attr3 = 8
@nac3
class NAC3Devices:
core: KernelInvariant[Core]
attr4: KernelInvariant[int32] = 16
def __init__(self):
self.core = Core()
@kernel
def run(self):
Demo.attr1 # Supported
# Demo.attr2 # Field not accessible on Kernel
# Demo.attr3 # Only attributes can be accessed in this way
# Demo.attr1 = 2 # Attributes are immutable
self.attr4 # Attributes can be accessed within class
obj = Demo()
obj.attr1 # Attributes can be accessed by class objects
NAC3Devices.attr4 # Attributes accessible for classes without __init__
if __name__ == "__main__":
NAC3Devices().run()

View File

@ -657,7 +657,7 @@ pub fn attributes_writeback(
}
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
attributes.push(name.to_string());
let index = ctx.get_attr_index(ty, *name);
let (index, _) = ctx.get_attr_index(ty, *name);
values.push((
*field_ty,
ctx.build_gep_and_load(

View File

@ -627,12 +627,15 @@ impl InnerResolver {
let pyid_to_def = self.pyid_to_def.read();
let constructor_ty = pyid_to_def.get(&py_obj_id).and_then(|def_id| {
defs.iter().find_map(|def| {
if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*def.read() {
if object_id == def_id
&& constructor.is_some()
&& methods.iter().any(|(s, _, _)| s == &"__init__".into())
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class { object_id, methods, constructor, .. } = &*rear_guard
{
return *constructor;
if object_id == def_id
&& constructor.is_some()
&& methods.iter().any(|(s, _, _)| s == &"__init__".into())
{
return *constructor;
}
}
}
None
@ -664,7 +667,29 @@ impl InnerResolver {
primitives,
)? {
Ok(s) => s,
Err(e) => return Ok(Err(e)),
Err(e) => {
// Allow access to Class Attributes of Classes without having to initialize Objects
if self.pyid_to_def.read().contains_key(&py_obj_id) {
if let Some(def_id) = self.pyid_to_def.read().get(&py_obj_id).copied() {
let def = defs[def_id.0].read();
let TopLevelDef::Class { object_id, .. } = &*def else {
// only object is supported, functions are not supported
unreachable!("function type is not supported, should not be queried")
};
let ty = TypeEnum::TObj {
obj_id: *object_id,
params: VarMap::new(),
fields: HashMap::new(),
};
(unifier.add_ty(ty), true)
} else {
return Ok(Err(e));
}
} else {
return Ok(Err(e));
}
}
};
match (&*unifier.get_ty(extracted_ty), inst_check) {
// do the instantiation for these four types

View File

@ -86,19 +86,35 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
get_subst_key(&mut self.unifier, obj, &fun.vars, filter)
}
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> usize {
/// Checks the field and attributes of classes
/// Returns the index of attr in class fields otherwise returns the attribute value
pub fn get_attr_index(&mut self, ty: Type, attr: StrRef) -> (usize, Option<Constant>) {
let obj_id = match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { obj_id, .. } => *obj_id,
// we cannot have other types, virtual type should be handled by function calls
_ => unreachable!(),
};
let def = &self.top_level.definitions.read()[obj_id.0];
let index = if let TopLevelDef::Class { fields, .. } = &*def.read() {
fields.iter().find_position(|x| x.0 == attr).unwrap().0
let (index, value) = if let TopLevelDef::Class { fields, attributes, .. } = &*def.read() {
if let Some(field_index) = fields.iter().find_position(|x| x.0 == attr) {
(field_index.0, None)
} else {
let attribute_index = attributes.iter().find_position(|x| x.0 == attr).unwrap();
(attribute_index.0, Some(attribute_index.1 .2.clone()))
}
} else {
unreachable!()
};
index
(index, value)
}
pub fn get_attr_index_object(&mut self, ty: Type, attr: StrRef) -> usize {
match &*self.unifier.get_ty(ty) {
TypeEnum::TObj { fields, .. } => {
fields.iter().find_position(|x| *x.0 == attr).unwrap().0
}
_ => unreachable!(),
}
}
pub fn gen_symbol_val<G: CodeGenerator + ?Sized>(
@ -2166,11 +2182,72 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}
ExprKind::Attribute { value, attr, .. } => {
// note that we would handle class methods directly in calls
// Change Class attribute access requests to accessing constants from Class Definition
if let Some(c) = value.custom {
if let TypeEnum::TFunc(_) = &*ctx.unifier.get_ty(c) {
let defs = ctx.top_level.definitions.read();
let result = defs.iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class {
constructor: Some(constructor),
attributes,
..
} = &*rear_guard
{
if *constructor == c {
return attributes.iter().find_map(|f| {
if f.0 == *attr {
// All other checks performed by this point
return Some(f.2.clone());
}
None
});
}
}
}
None
});
match result {
Some(val) => {
let mut modified_expr = expr.clone();
modified_expr.node = ExprKind::Constant { value: val, kind: None };
return generator.gen_expr(ctx, &modified_expr);
}
None => unreachable!("Function Type should not have attributes"),
}
} else if let TypeEnum::TObj { obj_id, fields, params } = &*ctx.unifier.get_ty(c) {
if fields.is_empty() && params.is_empty() {
let defs = ctx.top_level.definitions.read();
let def = defs[obj_id.0].read();
match if let TopLevelDef::Class { attributes, .. } = &*def {
attributes.iter().find_map(|f| {
if f.0 == *attr {
return Some(f.2.clone());
}
None
})
} else {
None
} {
Some(val) => {
let mut modified_expr = expr.clone();
modified_expr.node = ExprKind::Constant { value: val, kind: None };
return generator.gen_expr(ctx, &modified_expr);
}
None => unreachable!(),
}
}
}
}
match generator.gen_expr(ctx, value)? {
Some(ValueEnum::Static(v)) => v.get_field(*attr, ctx).map_or_else(
|| {
let v = v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?;
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr);
Ok(ValueEnum::Dynamic(ctx.build_gep_and_load(
v.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
@ -2180,7 +2257,14 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
Ok,
)?,
Some(ValueEnum::Dynamic(v)) => {
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
let (index, attr_value) = ctx.get_attr_index(value.custom.unwrap(), *attr);
if let Some(val) = attr_value {
// Change to Constant Construct
let mut modified_expr = expr.clone();
modified_expr.node = ExprKind::Constant { value: val, kind: None };
return generator.gen_expr(ctx, &modified_expr);
}
ValueEnum::Dynamic(ctx.build_gep_and_load(
v.into_pointer_value(),
&[zero, int32.const_int(index as u64, false)],
@ -2363,6 +2447,7 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ExprKind::Attribute { value, attr, .. } => {
let Some(val) = generator.gen_expr(ctx, value)? else { return Ok(None) };
// Handle Class Method calls
let id = if let TypeEnum::TObj { obj_id, .. } =
&*ctx.unifier.get_ty(value.custom.unwrap())
{

View File

@ -163,10 +163,11 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
) -> Result<NDArrayValue<'ctx>, String> {
let llvm_usize = generator.get_size_type(ctx.ctx);
for shape_dim in shape {
for &shape_dim in shape {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let shape_dim_gez = ctx
.builder
.build_int_compare(IntPredicate::SGE, *shape_dim, llvm_usize.const_zero(), "")
.build_int_compare(IntPredicate::SGE, shape_dim, llvm_usize.const_zero(), "")
.unwrap();
ctx.make_assert(
@ -189,7 +190,8 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
let ndarray_num_dims = ndarray.load_ndims(ctx);
ndarray.create_dim_sizes(ctx, llvm_usize, ndarray_num_dims);
for (i, shape_dim) in shape.iter().enumerate() {
for (i, &shape_dim) in shape.iter().enumerate() {
let shape_dim = ctx.builder.build_int_z_extend(shape_dim, llvm_usize, "").unwrap();
let ndarray_dim = unsafe {
ndarray.dim_sizes().ptr_offset_unchecked(
ctx,
@ -199,7 +201,7 @@ fn create_ndarray_const_shape<'ctx, G: CodeGenerator + ?Sized>(
)
};
ctx.builder.build_store(ndarray_dim, *shape_dim).unwrap();
ctx.builder.build_store(ndarray_dim, shape_dim).unwrap();
}
let ndarray = ndarray_init_data(generator, ctx, elem_ty, ndarray);
@ -286,22 +288,68 @@ fn ndarray_one_value<'ctx, G: CodeGenerator + ?Sized>(
///
/// * `elem_ty` - The element type of the `NDArray`.
/// * `shape` - The `shape` parameter used to construct the `NDArray`.
///
/// ### Notes on `shape`
///
/// Just like numpy, the `shape` argument can be:
/// 1. A list of `int32`; e.g., `np.empty([600, 800, 3])`
/// 2. A tuple of `int32`; e.g., `np.empty((600, 800, 3))`
/// 3. A scalar `int32`; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
///
/// See also [`typecheck::type_inferencer::fold_numpy_function_call_shape_argument`] to
/// learn how `shape` gets from being a Python user expression to here.
fn call_ndarray_empty_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
shape: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape,
|_, ctx, shape| Ok(shape.load_size(ctx, None)),
|generator, ctx, shape, idx| {
Ok(shape.data().get(ctx, generator, &idx, None).into_int_value())
},
)
let llvm_usize = generator.get_size_type(ctx.ctx);
match shape {
BasicValueEnum::PointerValue(shape_list_ptr)
if ListValue::is_instance(shape_list_ptr, llvm_usize).is_ok() =>
{
// 1. A list of ints; e.g., `np.empty([600, 800, 3])`
let shape_list = ListValue::from_ptr_val(shape_list_ptr, llvm_usize, None);
create_ndarray_dyn_shape(
generator,
ctx,
elem_ty,
&shape_list,
|_, ctx, shape_list| Ok(shape_list.load_size(ctx, None)),
|generator, ctx, shape_list, idx| {
Ok(shape_list.data().get(ctx, generator, &idx, None).into_int_value())
},
)
}
BasicValueEnum::StructValue(shape_tuple) => {
// 2. A tuple of ints; e.g., `np.empty((600, 800, 3))`
// Read [`codegen::expr::gen_expr`] to see how `nac3core` translates a Python tuple into LLVM.
// Get the length/size of the tuple, which also happens to be the value of `ndims`.
let ndims = shape_tuple.get_type().count_fields();
let mut shape = Vec::with_capacity(ndims as usize);
for dim_i in 0..ndims {
let dim = ctx
.builder
.build_extract_value(shape_tuple, dim_i, format!("dim{dim_i}").as_str())
.unwrap()
.into_int_value();
shape.push(dim);
}
create_ndarray_const_shape(generator, ctx, elem_ty, shape.as_slice())
}
BasicValueEnum::IntValue(shape_int) => {
// 3. A scalar int; e.g., `np.empty(3)`, this is functionally equivalent to `np.empty([3])`
create_ndarray_const_shape(generator, ctx, elem_ty, &[shape_int])
}
_ => unreachable!(),
}
}
/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as
@ -486,7 +534,7 @@ fn call_ndarray_zeros_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
shape: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
@ -517,7 +565,7 @@ fn call_ndarray_ones_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
shape: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let supported_types = [
ctx.primitives.int32,
@ -548,7 +596,7 @@ fn call_ndarray_full_impl<'ctx, G: CodeGenerator + ?Sized>(
generator: &mut G,
ctx: &mut CodeGenContext<'ctx, '_>,
elem_ty: Type,
shape: ListValue<'ctx>,
shape: BasicValueEnum<'ctx>,
fill_value: BasicValueEnum<'ctx>,
) -> Result<NDArrayValue<'ctx>, String> {
let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?;
@ -1674,17 +1722,11 @@ pub fn gen_ndarray_empty<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_empty_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
call_ndarray_empty_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.zeros`.
@ -1698,17 +1740,11 @@ pub fn gen_ndarray_zeros<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_zeros_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
call_ndarray_zeros_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.ones`.
@ -1722,17 +1758,11 @@ pub fn gen_ndarray_ones<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 1);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
call_ndarray_ones_impl(
generator,
context,
context.primitives.float,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
)
.map(NDArrayValue::into)
call_ndarray_ones_impl(generator, context, context.primitives.float, shape_arg)
.map(NDArrayValue::into)
}
/// Generates LLVM IR for `ndarray.full`.
@ -1746,21 +1776,14 @@ pub fn gen_ndarray_full<'ctx>(
assert!(obj.is_none());
assert_eq!(args.len(), 2);
let llvm_usize = generator.get_size_type(context.ctx);
let shape_ty = fun.0.args[0].ty;
let shape_arg = args[0].1.clone().to_basic_value_enum(context, generator, shape_ty)?;
let fill_value_ty = fun.0.args[1].ty;
let fill_value_arg =
args[1].1.clone().to_basic_value_enum(context, generator, fill_value_ty)?;
call_ndarray_full_impl(
generator,
context,
fill_value_ty,
ListValue::from_ptr_val(shape_arg.into_pointer_value(), llvm_usize, None),
fill_value_arg,
)
.map(NDArrayValue::into)
call_ndarray_full_impl(generator, context, fill_value_ty, shape_arg, fill_value_arg)
.map(NDArrayValue::into)
}
pub fn gen_ndarray_array<'ctx>(

View File

@ -113,7 +113,7 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>(
}
},
ExprKind::Attribute { value, attr, .. } => {
let index = ctx.get_attr_index(value.custom.unwrap(), *attr);
let (index, _) = ctx.get_attr_index(value.custom.unwrap(), *attr);
let val = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?
} else {

View File

@ -83,6 +83,7 @@ pub fn get_exn_constructor(
object_id: DefinitionId(class_id),
type_vars: Vec::default(),
fields: exception_fields,
attributes: Vec::default(),
methods: vec![("__init__".into(), signature, DefinitionId(cons_id))],
ancestors: vec![
TypeAnnotation::CustomClass { id: DefinitionId(class_id), params: Vec::default() },
@ -323,6 +324,9 @@ struct BuiltinBuilder<'a> {
num_or_ndarray_ty: TypeVar,
num_or_ndarray_var_map: VarMap,
/// See [`BuiltinBuilder::build_ndarray_from_shape_factory_function`]
ndarray_factory_fn_shape_arg_tvar: TypeVar,
}
impl<'a> BuiltinBuilder<'a> {
@ -393,6 +397,8 @@ impl<'a> BuiltinBuilder<'a> {
let list_int32 = unifier.add_ty(TypeEnum::TList { ty: int32 });
let ndarray_factory_fn_shape_arg_tvar = unifier.get_fresh_var(Some("Shape".into()), None);
BuiltinBuilder {
unifier,
primitives,
@ -420,6 +426,8 @@ impl<'a> BuiltinBuilder<'a> {
num_or_ndarray_ty,
num_or_ndarray_var_map,
ndarray_factory_fn_shape_arg_tvar,
}
}
@ -596,6 +604,7 @@ impl<'a> BuiltinBuilder<'a> {
object_id: prim.id(),
type_vars: Vec::default(),
fields: make_exception_fields(int32, int64, str),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: vec![],
constructor: None,
@ -624,7 +633,8 @@ impl<'a> BuiltinBuilder<'a> {
name: prim.name().into(),
object_id: prim.id(),
type_vars: vec![self.option_tvar.ty],
fields: vec![],
fields: Vec::default(),
attributes: Vec::default(),
methods: vec![
Self::create_method(PrimDef::OptionIsSome, self.is_some_ty.0),
Self::create_method(PrimDef::OptionIsNone, self.is_some_ty.0),
@ -738,6 +748,7 @@ impl<'a> BuiltinBuilder<'a> {
object_id: prim.id(),
type_vars: vec![self.ndarray_dtype_tvar.ty, self.ndarray_ndims_tvar.ty],
fields: Vec::default(),
attributes: Vec::default(),
methods: vec![
Self::create_method(PrimDef::NDArrayCopy, self.ndarray_copy_ty.0),
Self::create_method(PrimDef::NDArrayFill, self.ndarray_fill_ty.0),
@ -955,21 +966,46 @@ impl<'a> BuiltinBuilder<'a> {
)
}
/// Build ndarray factory functions that only take in an argument `shape` of type `list[int32]` and return an ndarray.
/// Build ndarray factory functions that only take in an argument `shape`.
///
/// `shape` can be a tuple of int32s, a list of int32s, or a scalar int32.
fn build_ndarray_from_shape_factory_function(&mut self, prim: PrimDef) -> TopLevelDef {
debug_assert_prim_is_allowed(
prim,
&[PrimDef::FunNpNDArray, PrimDef::FunNpEmpty, PrimDef::FunNpZeros, PrimDef::FunNpOnes],
);
// NOTE: on `ndarray_factory_fn_shape_arg_tvar` and
// the `param_ty` for `create_fn_by_codegen`.
//
// Ideally, we should have created a [`TypeVar`] to define all possible input
// types for the parameter "shape" like so:
// ```rust
// self.unifier.get_fresh_var_with_range(
// &[int32, list_int32, /* and more... */],
// Some("T".into()), None)
// )
// ```
//
// However, there is (currently) no way to type a tuple of arbitrary length in `nac3core`.
//
// And this is the best we could do:
// ```rust
// &[ int32, list_int32, tuple_1_int32, tuple_2_int32, tuple_3_int32, ... ],
// ```
//
// But this is not ideal.
//
// Instead, we delegate the responsibility of typechecking
// to [`typecheck::type_inferencer::Inferencer::fold_numpy_function_call_shape_argument`],
// and use a dummy [`TypeVar`] `ndarray_factory_fn_shape_arg_tvar` as a placeholder for `param_ty`.
create_fn_by_codegen(
self.unifier,
&VarMap::new(),
prim.name(),
self.ndarray_float,
// We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a
// type variable
&[(self.list_int32, "shape")],
&[(self.ndarray_factory_fn_shape_arg_tvar.ty, "shape")],
Box::new(move |ctx, obj, fun, args, generator| {
let func = match prim {
PrimDef::FunNpNDArray | PrimDef::FunNpEmpty => gen_ndarray_empty,

View File

@ -1057,7 +1057,14 @@ impl TopLevelComposer {
let (keyword_list, core_config) = core_info;
let mut class_def = class_def.write();
let TopLevelDef::Class {
object_id, ancestors, fields, methods, resolver, type_vars, ..
object_id,
ancestors,
fields,
attributes,
methods,
resolver,
type_vars,
..
} = &mut *class_def
else {
unreachable!("here must be toplevel class def");
@ -1073,10 +1080,14 @@ impl TopLevelComposer {
class_body_ast,
_class_ancestor_def,
class_fields_def,
class_attributes_def,
class_methods_def,
class_type_vars_def,
class_resolver,
) = (*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver);
) = (
*object_id, *name, bases, body, ancestors, fields, attributes, methods, type_vars,
resolver,
);
let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.as_ref();
@ -1285,34 +1296,74 @@ impl TopLevelComposer {
.unify(method_dummy_ty, method_type)
.map_err(|e| HashSet::from([e.to_display(unifier).to_string()]))?;
}
ast::StmtKind::AnnAssign { target, annotation, value: None, .. } => {
ast::StmtKind::AnnAssign { target, annotation, value, .. } => {
if let ast::ExprKind::Name { id: attr, .. } = &target.node {
if defined_fields.insert(attr.to_string()) {
let dummy_field_type = unifier.get_dummy_var().ty;
// handle Kernel[T], KernelInvariant[T]
let (annotation, mutable) = match &annotation.node {
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into()
) =>
{
(slice, false)
let annotation = match value {
None => {
// handle Kernel[T], KernelInvariant[T]
let (annotation, mutable) = match &annotation.node {
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if id == &core_config.kernel_invariant_ann.into()
) =>
{
(slice, false)
}
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into())
) =>
{
(slice, true)
}
_ if core_config.kernel_ann.is_none() => (annotation, true),
_ => continue, // ignore fields annotated otherwise
};
class_fields_def.push((*attr, dummy_field_type, mutable));
annotation
}
ast::ExprKind::Subscript { value, slice, .. }
if matches!(
&value.node,
ast::ExprKind::Name { id, .. } if core_config.kernel_ann.map_or(false, |c| id == &c.into())
) =>
{
(slice, true)
}
_ if core_config.kernel_ann.is_none() => (annotation, true),
_ => continue, // ignore fields annotated otherwise
};
class_fields_def.push((*attr, dummy_field_type, mutable));
// Supporting Class Attributes
Some(boxed_expr) => {
// Class attributes are set as immutable regardless
let (annotation, _) = match &annotation.node {
ast::ExprKind::Subscript { slice, .. } => (slice, false),
_ if core_config.kernel_ann.is_none() => (annotation, false),
_ => continue,
};
match &**boxed_expr {
ast::Located {location: _, custom: (), node: ast::ExprKind::Constant { value: v, kind: _ }} => {
// Restricting the types allowed to be defined as class attributes
match v {
ast::Constant::Bool(_) | ast::Constant::Str(_) | ast::Constant::Int(_) | ast::Constant::Float(_) => {}
_ => {
return Err(HashSet::from([
format!(
"unsupported statement in class definition body (at {})",
b.location
),
]))
}
}
class_attributes_def.push((*attr, dummy_field_type, v.clone()));
}
_ => {
return Err(HashSet::from([
format!(
"unsupported statement in class definition body (at {})",
b.location
),
]))
}
}
annotation
}
};
let parsed_annotation = parse_ast_to_type_annotation_kinds(
class_resolver,
temp_def_list,
@ -1384,7 +1435,14 @@ impl TopLevelComposer {
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
) -> Result<(), HashSet<String>> {
let TopLevelDef::Class {
object_id, ancestors, fields, methods, resolver, type_vars, ..
object_id,
ancestors,
fields,
attributes,
methods,
resolver,
type_vars,
..
} = class_def
else {
unreachable!("here must be class def ast")
@ -1393,10 +1451,11 @@ impl TopLevelComposer {
_class_id,
class_ancestor_def,
class_fields_def,
class_attribute_def,
class_methods_def,
_class_type_vars_def,
_class_resolver,
) = (*object_id, ancestors, fields, methods, type_vars, resolver);
) = (*object_id, ancestors, fields, attributes, methods, type_vars, resolver);
// since when this function is called, the ancestors of the direct parent
// are supposed to be already handled, so we only need to deal with the direct parent
@ -1407,7 +1466,7 @@ impl TopLevelComposer {
let base = temp_def_list.get(id.0).unwrap();
let base = base.read();
let TopLevelDef::Class { methods, fields, .. } = &*base else {
let TopLevelDef::Class { methods, fields, attributes, .. } = &*base else {
unreachable!("must be top level class def")
};
@ -1449,7 +1508,7 @@ impl TopLevelComposer {
}
}
// use the new_child_methods to replace all the elements in `class_methods_def`
class_methods_def.drain(..);
class_methods_def.clear();
class_methods_def.extend(new_child_methods);
// handle class fields
@ -1459,7 +1518,9 @@ impl TopLevelComposer {
let to_be_added = (*anc_field_name, *anc_field_ty, *mutable);
// find if there is a fields with the same name in the child class
for (class_field_name, ..) in &*class_fields_def {
if class_field_name == anc_field_name {
if class_field_name == anc_field_name
|| attributes.iter().any(|f| f.0 == *class_field_name)
{
return Err(HashSet::from([format!(
"field `{class_field_name}` has already declared in the ancestor classes"
)]));
@ -1467,14 +1528,33 @@ impl TopLevelComposer {
}
new_child_fields.push(to_be_added);
}
// handle class attributes
let mut new_child_attributes: Vec<(StrRef, Type, ast::Constant)> = Vec::new();
for (anc_attr_name, anc_attr_ty, attr_value) in attributes {
let to_be_added = (*anc_attr_name, *anc_attr_ty, attr_value.clone());
// find if there is a attribute with the same name in the child class
for (class_attr_name, ..) in &*class_attribute_def {
if class_attr_name == anc_attr_name
|| fields.iter().any(|f| f.0 == *class_attr_name)
{
return Err(HashSet::from([format!(
"attribute `{class_attr_name}` has already declared in the ancestor classes"
)]));
}
}
new_child_attributes.push(to_be_added);
}
for (class_field_name, class_field_ty, mutable) in &*class_fields_def {
if !is_override.contains(class_field_name) {
new_child_fields.push((*class_field_name, *class_field_ty, *mutable));
}
}
class_fields_def.drain(..);
class_fields_def.clear();
class_fields_def.extend(new_child_fields);
class_attribute_def.clear();
class_attribute_def.extend(new_child_attributes);
Ok(())
}

View File

@ -474,6 +474,7 @@ impl TopLevelComposer {
object_id: obj_id,
type_vars: Vec::default(),
fields: Vec::default(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
constructor,

View File

@ -103,6 +103,10 @@ pub enum TopLevelDef {
///
/// Name and type is mutable.
fields: Vec<(StrRef, Type, bool)>,
/// Class Attributes.
///
/// Name, type, value.
attributes: Vec<(StrRef, Type, ast::Constant)>,
/// Class methods, pointing to the corresponding function definition.
methods: Vec<(StrRef, Type, DefinitionId)>,
/// Ancestor classes, including itself.

View File

@ -5,7 +5,7 @@ expression: res_vec
[
"Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n",
"Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(239)]\n}\n",
"Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(240)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n",

View File

@ -7,7 +7,7 @@ expression: res_vec
"Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar228]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar228\"]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B[typevar229]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar229\"]\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n",

View File

@ -5,8 +5,8 @@ expression: res_vec
[
"Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(241)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(246)]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(242)]\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(247)]\n}\n",
"Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",

View File

@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs
expression: res_vec
---
[
"Class {\nname: \"A\",\nancestors: [\"A[typevar227, typevar228]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar227\", \"typevar228\"]\n}\n",
"Class {\nname: \"A\",\nancestors: [\"A[typevar228, typevar229]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar228\", \"typevar229\"]\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n",

View File

@ -6,12 +6,12 @@ expression: res_vec
"Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n",
"Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(248)]\n}\n",
"Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n",
"Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(255)]\n}\n",
"Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(256)]\n}\n",
]

View File

@ -470,6 +470,7 @@ pub fn get_type_from_type_annotation_kinds(
}
result
};
// Class Attributes keep a copy with Class Definition and are not added to objects
let mut tobj_fields = methods
.iter()
.map(|(name, ty, _)| {

View File

@ -34,6 +34,7 @@ pub enum TypeErrorKind {
},
RequiresTypeAnn,
PolymorphicFunctionPointer,
NoSuchAttribute(RecordKey, Type),
}
#[derive(Debug, Clone)]
@ -156,6 +157,10 @@ impl<'a> Display for DisplayTypeError<'a> {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` field/method does not exist")
}
NoSuchAttribute(name, t) => {
let t = self.unifier.stringify_with_notes(*t, &mut notes);
write!(f, "`{t}::{name}` is not a class attribute")
}
TupleIndexOutOfBounds { index, len } => {
write!(
f,

View File

@ -6,6 +6,7 @@ use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::toplevel::TopLevelDef;
use crate::{
symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{
@ -820,6 +821,155 @@ impl<'a> Inferencer<'a> {
})
}
/// Fold an ndarray `shape` argument. This function aims to fold `shape` arguments like that of
/// <https://numpy.org/doc/stable/reference/generated/numpy.zeros.html> (for `np_zeros`).
///
/// Arguments:
/// * `id` - The name of the function of the function call this `shape` argument is in. Used for error reporting.
/// * `arg_index` - The position (0-based) of this argument in the function call. Used for error reporting.
/// * `shape_expr` - [`Located<ExprKind>`] of the input argument.
///
/// On success, it returns a tuple of
/// 1) the `ndims` value inferred from the input `shape`,
/// 2) and the elaborated expression. Like what other fold functions of [`Inferencer`] would normally return.
fn fold_numpy_function_call_shape_argument(
&mut self,
id: StrRef,
arg_index: usize,
shape_expr: Located<ExprKind>,
) -> Result<(u64, ast::Expr<Option<Type>>), HashSet<String>> {
/*
### Further explanation
As said, this function aims to fold `shape` arguments, but this is *not* trivial.
The root of the issue is that `nac3core` has to deduce the `ndims`
of the created (for in the case of `np_zeros`) ndarray statically - i.e., during inference time.
There are three types of valid input to `shape`:
1. A python `List` (all `int32s`); e.g., `np_zeros([600, 800, 3])`
2. A python `Tuple` (all `int32s`); e.g., `np_zeros((600, 800, 3))`
3. An `int32`; e.g., `np_zeros(256)` - this is functionally equivalent to `np_zeros([256])`
For 2. and 3., `ndims` can be deduce immediately from the inferred type of the input:
- For 2. `ndims` is simply the number of elements found in [`TypeEnum::TTuple`] after typechecking the `shape` argument.
- For 3. `ndims` is simply 1.
For 1., `ndims` is supposedly the length of the input list. However, the length of the input list
is a runtime property. Therefore (as a hack) we resort to analyzing the argument expression [`ExprKind::List`]
itself to extract the input list length statically.
This implies that the user could only write:
```python
my_rgba_image = np_zeros([600, 800, 4])
# the shape argument is directly written as a list literal.
# and `nac3core` could therefore tell that ndims is `3` by
# looking at the raw AST expression itself.
```
But not:
```python
my_image_dimension = [600, 800, 4]
mystery_function_that_mutates_my_list(my_image_dimension)
my_image = np_zeros(my_image_dimension)
# what is the length now? what is `ndims`?
# it is *basically impossible* to generally determine the
# length of `my_image_dimension` statically for `ndims`!!
```
*/
// Auxillary details for error reporting.
// Predefined here because `shape_expr` will be moved when doing `fold_expr`
let shape_expr_name = shape_expr.node.name();
let shape_location = shape_expr.location;
// Fold `shape`
let shape = self.fold_expr(shape_expr)?;
let shape_ty = shape.custom.unwrap(); // The inferred type of `shape`
// Check `shape_ty` to see if its a list of int32s, a tuple of int32s, or just int32.
// Otherwise throw an error as that would mean the user wrote an ill-typed `shape_expr`.
//
// Here, we also take the opportunity to deduce `ndims` statically for 2. and 3.
let shape_ty_enum = &*self.unifier.get_ty(shape_ty);
let ndims = match shape_ty_enum {
TypeEnum::TList { ty } => {
// Handle 1. A list of int32s
// Typecheck
self.unifier.unify(*ty, self.primitives.int32).map_err(|err| {
HashSet::from([err
.at(Some(shape_location))
.to_display(self.unifier)
.to_string()])
})?;
// Special handling for (1. A python `List` (all `int32s`)).
// Read the doc above this function to see what is going on here.
if let ExprKind::List { elts, .. } = &shape.node {
// The user wrote a List literal as the input argument
elts.len() as u64
} else {
// This means the user is passing an expression of type `List`,
// but it is done so indirectly (like putting a variable referencing a `List`)
// rather than writing a List literal. We need to report an error.
return Err(HashSet::from([
format!(
"Expected List (must be a literal)/Tuple/int32 for argument {arg_num} of {id} at {shape_location}. \
There, you are passing a value of type List as the argument. \
However, this argument is special - you must only supply this argument with a List literal. \
On the other hand, you may instead pass in a tuple, and there would be no such restriction.",
arg_num = arg_index + 1
)
]));
}
}
TypeEnum::TTuple { ty: tuple_element_types } => {
// Handle 2. A tuple of int32s
// Typecheck
// The expected type is just the tuple but with all its elements being int32.
let expected_ty = self.unifier.add_ty(TypeEnum::TTuple {
ty: tuple_element_types.iter().map(|_| self.primitives.int32).collect_vec(),
});
self.unifier.unify(shape_ty, expected_ty).map_err(|err| {
HashSet::from([err
.at(Some(shape_location))
.to_display(self.unifier)
.to_string()])
})?;
// `ndims` can be deduced statically from the inferred Tuple type.
tuple_element_types.len() as u64
}
TypeEnum::TObj { .. } => {
// Handle 3. An integer (generalized as [`TypeEnum::TObj`])
// Typecheck
self.unify(self.primitives.int32, shape_ty, &shape_location)?;
// Deduce `ndims`
1
}
_ => {
// The user wrote an ill-typed `shape_expr`,
// so throw an error.
let shape_ty_str = self.unifier.stringify(shape_ty);
return report_error(
format!(
"Expected List (must be a literal)/Tuple/integer for first argument of {id}, got {shape_expr_name} of type {shape_ty_str}",
)
.as_str(),
shape_location,
);
}
};
Ok((ndims, shape))
}
/// Tries to fold a special call. Returns [`Some`] if the call expression `func` is a special call, otherwise
/// returns [`None`].
fn try_fold_special_call(
@ -1147,25 +1297,15 @@ impl<'a> Inferencer<'a> {
}));
}
// 1-argument ndarray n-dimensional creation functions
// 1-argument ndarray n-dimensional factory functions
if ["np_ndarray".into(), "np_empty".into(), "np_zeros".into(), "np_ones".into()]
.contains(id)
&& args.len() == 1
{
let ExprKind::List { elts, .. } = &args[0].node else {
return report_error(
format!(
"Expected List literal for first argument of {id}, got {}",
args[0].node.name()
)
.as_str(),
args[0].location,
);
};
let shape_expr = args.remove(0);
let (ndims, shape) =
self.fold_numpy_function_call_shape_argument(*id, 0, shape_expr)?; // Special handling the `shape`
let ndims = elts.len() as u64;
let arg0 = self.fold_expr(args.remove(0))?;
let ndims = self.unifier.get_fresh_literal(vec![SymbolValue::U64(ndims)], None);
let ret = make_ndarray_ty(
self.unifier,
@ -1176,7 +1316,7 @@ impl<'a> Inferencer<'a> {
let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: vec![FuncArg {
name: "shape".into(),
ty: arg0.custom.unwrap(),
ty: shape.custom.unwrap(),
default_value: None,
}],
ret,
@ -1192,7 +1332,7 @@ impl<'a> Inferencer<'a> {
location: func.location,
node: ExprKind::Name { id: *id, ctx: *ctx },
}),
args: vec![arg0],
args: vec![shape],
keywords: vec![],
},
}));
@ -1441,6 +1581,24 @@ impl<'a> Inferencer<'a> {
Ok(self.unifier.add_ty(TypeEnum::TTuple { ty }))
}
/// Checks for non-class attributes
fn infer_general_attribute(
&mut self,
value: &ast::Expr<Option<Type>>,
attr: StrRef,
ctx: ExprContext,
) -> InferenceResult {
let attr_ty = self.unifier.get_dummy_var().ty;
let fields = once((
attr.into(),
RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)),
))
.collect();
let record = self.unifier.add_record(fields);
self.constrain(value.custom.unwrap(), record, &value.location)?;
Ok(attr_ty)
}
fn infer_attribute(
&mut self,
value: &ast::Expr<Option<Type>>,
@ -1448,31 +1606,72 @@ impl<'a> Inferencer<'a> {
ctx: ExprContext,
) -> InferenceResult {
let ty = value.custom.unwrap();
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
if let TypeEnum::TObj { obj_id, fields, .. } = &*self.unifier.get_ty(ty) {
// just a fast path
match (fields.get(&attr), ctx == ExprContext::Store) {
(Some((ty, true)), _) | (Some((ty, false)), false) => Ok(*ty),
(Some((_, false)), true) => {
report_error(&format!("Field `{attr}` is immutable"), value.location)
}
(None, _) => {
let t = self.unifier.stringify(ty);
report_error(
&format!("`{t}::{attr}` field/method does not exist"),
value.location,
)
(None, mutable) => {
// Check whether it is a class attribute
let defs = self.top_level.definitions.read();
let result = {
if let TopLevelDef::Class { attributes, .. } = &*defs[obj_id.0].read() {
attributes.iter().find_map(|f| {
if f.0 == attr {
return Some(f.1);
}
None
})
} else {
None
}
};
match result {
Some(res) if !mutable => Ok(res),
Some(_) => report_error(
&format!("Class Attribute `{attr}` is immutable"),
value.location,
),
None => {
let t = self.unifier.stringify(ty);
report_error(
&format!("`{t}::{attr}` field/method does not exist"),
value.location,
)
}
}
}
}
} else if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(ty) {
// Access Class Attributes of classes with __init__ function using Class names e.g. Foo.ATTR1
let result = {
self.top_level.definitions.read().iter().find_map(|def| {
if let Some(rear_guard) = def.try_read() {
if let TopLevelDef::Class { name, attributes, .. } = &*rear_guard {
if name.to_string() == self.unifier.stringify(sign.ret) {
return attributes.iter().find_map(|f| {
if f.0 == attr {
return Some(f.clone().1);
}
None
});
}
}
}
None
})
};
match result {
Some(f) if ctx != ExprContext::Store => Ok(f),
Some(_) => {
report_error(&format!("Class Attribute `{attr}` is immutable"), value.location)
}
None => self.infer_general_attribute(value, attr, ctx),
}
} else {
let attr_ty = self.unifier.get_dummy_var().ty;
let fields = once((
attr.into(),
RecordField::new(attr_ty, ctx == ExprContext::Store, Some(value.location)),
))
.collect();
let record = self.unifier.add_record(fields);
self.constrain(value.custom.unwrap(), record, &value.location)?;
Ok(attr_ty)
self.infer_general_attribute(value, attr, ctx)
}
}

View File

@ -289,6 +289,7 @@ impl TestEnvironment {
object_id: DefinitionId(i),
type_vars: Vec::default(),
fields: Vec::default(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -331,6 +332,7 @@ impl TestEnvironment {
object_id: DefinitionId(defs + 1),
type_vars: vec![tvar.ty],
fields: [("a".into(), tvar.ty, true)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -365,6 +367,7 @@ impl TestEnvironment {
object_id: DefinitionId(defs + 2),
type_vars: Vec::default(),
fields: [("a".into(), int32, true), ("b".into(), fun, true)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,
@ -393,6 +396,7 @@ impl TestEnvironment {
object_id: DefinitionId(defs + 3),
type_vars: Vec::default(),
fields: [("a".into(), bool, true), ("b".into(), fun, false)].into(),
attributes: Vec::default(),
methods: Vec::default(),
ancestors: Vec::default(),
resolver: None,

View File

@ -71,17 +71,44 @@ def output_ndarray_float_2(n: ndarray[float, Literal[2]]):
def consume_ndarray_1(n: ndarray[float, Literal[1]]):
pass
def consume_ndarray_2(n: ndarray[float, Literal[2]]):
pass
def test_ndarray_ctor():
n: ndarray[float, Literal[1]] = np_ndarray([1])
consume_ndarray_1(n)
def test_ndarray_empty():
n: ndarray[float, 1] = np_empty([1])
consume_ndarray_1(n)
n1: ndarray[float, 1] = np_empty([1])
consume_ndarray_1(n1)
n2: ndarray[float, 1] = np_empty(10)
consume_ndarray_1(n2)
n3: ndarray[float, 1] = np_empty((2,))
consume_ndarray_1(n3)
n4: ndarray[float, 2] = np_empty((4, 4))
consume_ndarray_2(n4)
dim4 = (5, 2)
n5: ndarray[float, 2] = np_empty(dim4)
consume_ndarray_2(n5)
def test_ndarray_zeros():
n: ndarray[float, 1] = np_zeros([1])
output_ndarray_float_1(n)
n1: ndarray[float, 1] = np_zeros([1])
output_ndarray_float_1(n1)
k = 3 + int32(n1[0]) # to test variable shape inputs
n2: ndarray[float, 1] = np_zeros(k * k)
output_ndarray_float_1(n2)
n3: ndarray[float, 1] = np_zeros((k * 2,))
output_ndarray_float_1(n3)
dim4 = (3, 2 * k)
n4: ndarray[float, 2] = np_zeros(dim4)
output_ndarray_float_2(n4)
def test_ndarray_ones():
n: ndarray[float, 1] = np_ones([1])