Compare commits

...

3 Commits

Author SHA1 Message Date
e614dd4257 core/type_inferencer: Fix location of unary/compare expressions
Codegen uses this location information to determine the CallId, and if
a function call is the operand of a unary expression or left-hand
operand of a compare expression, codegen will use the type of the
operator expression rather than the actual operand type.
2024-04-05 15:42:10 +08:00
937a8b9698 core/magic_methods: Fix type of unary ops with primitive types 2024-04-05 13:23:08 +08:00
876ad6c59c core/type_inferencer: Include location info if inferencer fails 2024-04-05 13:22:35 +08:00
2 changed files with 36 additions and 23 deletions

View File

@ -476,10 +476,24 @@ pub fn typeof_unaryop(
return Err("The truth value of an array with more than one element is ambiguous".to_string()) return Err("The truth value of an array with more than one element is ambiguous".to_string())
} }
Ok(if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) { Ok(match *op {
Some(operand) Unaryop::Not => {
} else { match operand.obj_id(unifier) {
None Some(v) if v == PRIMITIVE_DEF_IDS.ndarray => Some(operand),
Some(_) => Some(primitives.bool),
_ => None
}
}
Unaryop::Invert
| Unaryop::UAdd
| Unaryop::USub => {
if operand.obj_id(unifier).is_some_and(|id| PRIMITIVE_DEF_IDS.iter().any(|prim_id| id == prim_id)) {
Some(operand)
} else {
None
}
}
}) })
} }

View File

@ -4,7 +4,7 @@ use std::iter::once;
use std::{cell::RefCell, sync::Arc}; use std::{cell::RefCell, sync::Arc};
use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier, VarMap};
use super::{magic_methods::*, typedef::CallId}; use super::{magic_methods::*, type_error::TypeError, typedef::CallId};
use crate::{ use crate::{
symbol_resolver::{SymbolResolver, SymbolValue}, symbol_resolver::{SymbolResolver, SymbolValue},
toplevel::{ toplevel::{
@ -553,7 +553,7 @@ impl<'a> Fold<()> for Inferencer<'a> {
Some(self.infer_unary_ops(expr.location, op, operand)?) Some(self.infer_unary_ops(expr.location, op, operand)?)
} }
ExprKind::Compare { left, ops, comparators } => { ExprKind::Compare { left, ops, comparators } => {
Some(self.infer_compare(left, ops, comparators)?) Some(self.infer_compare(expr.location, left, ops, comparators)?)
} }
ExprKind::Subscript { value, slice, ctx, .. } => { ExprKind::Subscript { value, slice, ctx, .. } => {
Some(self.infer_subscript(value.as_ref(), slice.as_ref(), ctx)?) Some(self.infer_subscript(value.as_ref(), slice.as_ref(), ctx)?)
@ -628,7 +628,14 @@ impl<'a> Inferencer<'a> {
loc: Some(location), loc: Some(location),
}; };
if let Some(ret) = ret { if let Some(ret) = ret {
self.unifier.unify(sign.ret, ret).unwrap(); self.unifier.unify(sign.ret, ret)
.map_err(|err| {
format!("Cannot unify {} <: {} - {:?}",
self.unifier.stringify(sign.ret),
self.unifier.stringify(ret),
TypeError::new(err.kind, Some(location)))
})
.unwrap();
} }
let required: Vec<_> = sign let required: Vec<_> = sign
.args .args
@ -1262,11 +1269,12 @@ impl<'a> Inferencer<'a> {
operand.custom.unwrap(), operand.custom.unwrap(),
).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?; ).map_err(|e| HashSet::from([format!("{e} (at {location})")]))?;
self.build_method_call(operand.location, method, operand.custom.unwrap(), vec![], ret) self.build_method_call(location, method, operand.custom.unwrap(), vec![], ret)
} }
fn infer_compare( fn infer_compare(
&mut self, &mut self,
location: Location,
left: &ast::Expr<Option<Type>>, left: &ast::Expr<Option<Type>>,
ops: &[ast::Cmpop], ops: &[ast::Cmpop],
comparators: &[ast::Expr<Option<Type>>], comparators: &[ast::Expr<Option<Type>>],
@ -1275,6 +1283,7 @@ impl<'a> Inferencer<'a> {
return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")])) return Err(HashSet::from([String::from("Comparator chaining with ndarray types not supported")]))
} }
let mut res = None;
for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) { for (a, b, c) in izip!(once(left).chain(comparators), comparators, ops) {
let method = comparison_name(c) let method = comparison_name(c)
.ok_or_else(|| HashSet::from([ .ok_or_else(|| HashSet::from([
@ -1289,27 +1298,17 @@ impl<'a> Inferencer<'a> {
a.custom.unwrap(), a.custom.unwrap(),
b.custom.unwrap(), b.custom.unwrap(),
).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?; ).map_err(|e| HashSet::from([format!("{e} (at {})", b.location)]))?;
self.build_method_call( res.replace(self.build_method_call(
a.location, location,
method, method,
a.custom.unwrap(), a.custom.unwrap(),
vec![b.custom.unwrap()], vec![b.custom.unwrap()],
ret, ret,
)?; )?);
} }
let res_lhs = comparators.iter().rev().nth(1).unwrap_or(left); Ok(res.unwrap())
let res_rhs = comparators.iter().rev().nth(0).unwrap();
let res_op = ops.iter().rev().nth(0).unwrap();
Ok(typeof_cmpop(
self.unifier,
self.primitives,
res_op,
res_lhs.custom.unwrap(),
res_rhs.custom.unwrap(),
).unwrap().unwrap())
} }
/// Infers the type of a subscript expression on an `ndarray`. /// Infers the type of a subscript expression on an `ndarray`.