diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 06d473ff..8d05c6d2 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -389,7 +389,7 @@ impl<'a> Fold<()> for Inferencer<'a> { } ast::StmtKind::Assign { targets, value, .. } => { for target in targets { - self.unify(target.custom.unwrap(), value.custom.unwrap(), &target.location)?; + self.fold_assign(target, value)?; } } ast::StmtKind::Raise { exc, cause, .. } => { @@ -2159,4 +2159,58 @@ impl<'a> Inferencer<'a> { self.constrain(body.custom.unwrap(), orelse.custom.unwrap(), &body.location)?; Ok(body.custom.unwrap()) } + + fn fold_assign( + &mut self, + target: &ast::Expr>, + value: &ast::Expr>, + ) -> Result<(), HashSet> { + let target_ty = target.custom.unwrap(); + let value_ty = value.custom.unwrap(); + + match (&target.node, &*self.unifier.get_ty(target_ty)) { + (ExprKind::Subscript { .. }, TypeEnum::TObj { obj_id: target_obj_id, .. }) + if *target_obj_id == self.primitives.ndarray.obj_id(self.unifier).unwrap() => + { + // Pattern match expressions like `my_ndarray[slices] = value`. + // TODO: `(my_ndarray[slices1], my_ndarray[slices2]) = (value1, value2)` are not supported for now. + + // Suppose `my_ndarray` has type `ndarray[target_dtype, ndims]` + // value's type could be one of the following: + // Case 1. `target_dtype` + // Case 2. `ndarray[target_dtype, ?]` + // Case 3. list, tuple, iterables (TODO: NOT IMPLEMENTED) + + let (target_dtype, _) = unpack_ndarray_var_tys(self.unifier, target_ty); + + // Typecheck `value_ty` + match &*self.unifier.get_ty(value_ty) { + TypeEnum::TObj { obj_id: value_obj_id, .. } + if *value_obj_id + == self.primitives.ndarray.obj_id(self.unifier).unwrap() => + { + // Case 2 + // - `dtype` of `target_ty` and `value_ty` must unify. + // - `ndims` of `value_ty` is ignored. + let (value_dtype, _) = unpack_ndarray_var_tys(self.unifier, value_ty); + + self.unify(target_dtype, value_dtype, &target.location)?; + } + _ => { + // If `value_ty` is not an ndarray, simply typecheck as through it has to be Case 1. + self.unify(target_dtype, value_ty, &target.location)?; + } + } + } + _ => { + // To handle + // - variable assignments `target = value` + // - and attribute assignments `target.my_attr = value` + // + // For these cases in nac3core, types of LHS and RHS must unify + self.unify(target_ty, value_ty, &target.location)?; + } + } + Ok(()) + } }