hm-inference #6
|
@ -30,6 +30,15 @@ pub struct Inferencer<'a> {
|
||||||
primitives: &'a PrimitiveStore,
|
primitives: &'a PrimitiveStore,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct NaiveFolder();
|
||||||
|
impl fold::Fold<()> for NaiveFolder {
|
||||||
|
type TargetU = Option<Type>;
|
||||||
|
type Error = String;
|
||||||
|
fn map_user(&mut self, _: ()) -> Result<Self::TargetU, Self::Error> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a> fold::Fold<()> for Inferencer<'a> {
|
impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
type TargetU = Option<Type>;
|
type TargetU = Option<Type>;
|
||||||
type Error = String;
|
type Error = String;
|
||||||
|
@ -38,6 +47,66 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn fold_stmt(&mut self, node: ast::Stmt<()>) -> Result<ast::Stmt<Self::TargetU>, Self::Error> {
|
||||||
|
let stmt = match node.node {
|
||||||
|
// we don't want fold over type annotation
|
||||||
|
ast::StmtKind::AnnAssign {
|
||||||
|
target,
|
||||||
|
annotation,
|
||||||
|
value,
|
||||||
|
simple,
|
||||||
|
} => {
|
||||||
|
let target = Box::new(fold::fold_expr(self, *target)?);
|
||||||
|
let value = if let Some(v) = value {
|
||||||
|
let ty = Box::new(fold::fold_expr(self, *v)?);
|
||||||
|
self.unifier
|
||||||
|
.unify(target.custom.unwrap(), ty.custom.unwrap())?;
|
||||||
|
Some(ty)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let annotation_type = self
|
||||||
|
.resolver
|
||||||
|
.parse_type_name(annotation.as_ref())
|
||||||
|
.ok_or_else(|| "cannot parse type name".to_string())?;
|
||||||
|
self.unifier.unify(annotation_type, target.custom.unwrap())?;
|
||||||
|
let annotation = Box::new(NaiveFolder().fold_expr(*annotation)?);
|
||||||
|
Located {
|
||||||
|
location: node.location,
|
||||||
|
custom: None,
|
||||||
|
node: ast::StmtKind::AnnAssign {
|
||||||
|
target,
|
||||||
|
annotation,
|
||||||
|
value,
|
||||||
|
simple,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => fold::fold_stmt(self, node)?,
|
||||||
|
};
|
||||||
|
match &stmt.node {
|
||||||
|
ast::StmtKind::For { target, iter, .. } => {
|
||||||
|
let list = self.unifier.add_ty(TypeEnum::TList {
|
||||||
|
ty: target.custom.unwrap(),
|
||||||
|
});
|
||||||
|
self.unifier.unify(list, iter.custom.unwrap())?;
|
||||||
|
}
|
||||||
|
ast::StmtKind::If { test, .. } | ast::StmtKind::While { test, .. } => {
|
||||||
|
self.unifier
|
||||||
|
.unify(test.custom.unwrap(), self.primitives.bool)?;
|
||||||
|
}
|
||||||
|
ast::StmtKind::Assign { targets, value, .. } => {
|
||||||
|
for target in targets.iter() {
|
||||||
|
self.unifier
|
||||||
|
.unify(target.custom.unwrap(), value.custom.unwrap())?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ast::StmtKind::AnnAssign { .. } => {}
|
||||||
|
_ => return Err("Unsupported statement type".to_string())
|
||||||
|
};
|
||||||
|
Ok(stmt)
|
||||||
|
}
|
||||||
|
|
||||||
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
|
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
|
||||||
let expr = match node.node {
|
let expr = match node.node {
|
||||||
ast::ExprKind::Call {
|
ast::ExprKind::Call {
|
||||||
|
|
Loading…
Reference in New Issue