add wrapper, now can fold from Expr<()> to Expr<Option<Type>>; fix slice; some more testing

refactor_anto
CrescentonC 2021-07-16 11:28:32 +08:00
parent f33b3d3482
commit be512985a7
1 changed files with 125 additions and 19 deletions

View File

@ -8,7 +8,20 @@ use crate::typecheck::primitives;
use rustpython_parser::ast;
use rustpython_parser::ast::fold::Fold;
// REVIEW: direct impl fold trait on InferenceContext
struct Premapper;
impl ast::fold::Fold<()> for Premapper {
type TargetU = Option<Type>;
type Error = String;
fn map_user(&mut self, _user: ()) -> Result<Self::TargetU, Self::Error> {
Ok(None)
}
fn fold_expr(&mut self, node: ast::Expr<()>) -> Result<ast::Expr<Self::TargetU>, Self::Error> {
ast::fold::fold_expr(self, node)
}
}
impl<'a> ast::fold::Fold<Option<Type>> for InferenceContext<'a> {
type TargetU = Option<Type>;
type Error = String;
@ -43,6 +56,7 @@ impl<'a> ast::fold::Fold<Option<Type>> for InferenceContext<'a> {
ast::ExprKind::Subscript {value, slice, ctx: _} => self.infer_subscript(value, slice),
ast::ExprKind::IfExp {test, body, orelse} => self.infer_if_expr(test, body, orelse),
ast::ExprKind::ListComp {elt, generators} => self.infer_list_comprehesion(elt, generators),
ast::ExprKind::Slice { .. } => Ok(None), // special handling for slice, which is supported
_ => Err("not supported yet".into())
}?,
location: expr.location,
@ -408,32 +422,39 @@ impl<'a> InferenceContext<'a> {
}
}
pub struct ExpressionInferencer<'a> {
pub ctx: InferenceContext<'a>
}
impl<'a> ExpressionInferencer<'a> {
pub fn fold_expr(&mut self, expr: ast::Expr) -> Result<ast::Expr<Option<Type>>, String> {
let expr = Premapper.fold_expr(expr)?;
self.ctx.fold_expr(expr)
}
}
pub mod test {
use crate::typecheck::{symbol_resolver::SymbolResolver, typedef::*, symbol_resolver::*, location::*};
use rustpython_parser::ast::{self, Expr, fold::Fold};
use super::*;
pub fn new_ctx<'a>() -> InferenceContext<'a>{
pub fn new_ctx<'a>() -> ExpressionInferencer<'a> {
struct S;
impl SymbolResolver for S {
fn get_symbol_location(&self, _str: &str) -> Option<Location> { None }
fn get_symbol_type(&self, _str: &str) -> Option<SymbolType> { None }
fn get_symbol_value(&self, _str: &str) -> Option<SymbolValue> { None }
}
InferenceContext::new(primitives::basic_ctx(), Box::new(S{}), FileID(3))
ExpressionInferencer {ctx: InferenceContext::new(primitives::basic_ctx(), Box::new(S{}), FileID(3))}
}
#[test]
fn test_i32() {
let mut inferencer = new_ctx();
let ast: Expr<Option<Type>> = Expr {
let ast: Expr = Expr {
location: ast::Location::new(0, 0),
custom: None,
custom: (),
node: ast::ExprKind::Constant {
value: ast::Constant::Int(123.into()),
kind: None
@ -445,7 +466,7 @@ pub mod test {
new_ast,
Ok(ast::Expr {
location: ast::Location::new(0, 0),
custom: Some(inferencer.get_primitive(primitives::INT32_TYPE)),
custom: Some(inferencer.ctx.get_primitive(primitives::INT32_TYPE)),
node: ast::ExprKind::Constant {
value: ast::Constant::Int(123.into()),
kind: None
@ -461,9 +482,9 @@ pub mod test {
let location = ast::Location::new(0, 0);
let num: i64 = 99999999999;
let ast: Expr<Option<Type>> = Expr {
let ast: Expr = Expr {
location: location,
custom: None,
custom: (),
node: ast::ExprKind::Constant {
value: ast::Constant::Int(num.into()),
kind: None,
@ -476,7 +497,7 @@ pub mod test {
new_ast,
Expr {
location: location,
custom: Some(inferencer.get_primitive(primitives::INT64_TYPE)),
custom: Some(inferencer.ctx.get_primitive(primitives::INT64_TYPE)),
node: ast::ExprKind::Constant {
value: ast::Constant::Int(num.into()),
kind: None,
@ -485,20 +506,67 @@ pub mod test {
);
}
#[test]
fn test_tuple() {
let mut inferencer = new_ctx();
let i32_t = inferencer.ctx.get_primitive(primitives::INT32_TYPE);
let float_t = inferencer.ctx.get_primitive(primitives::FLOAT_TYPE);
let ast = rustpython_parser::parser::parse_expression("(123, 123.123, 999999999)").unwrap();
let loc = ast.location.clone();
let folded = inferencer.fold_expr(ast).unwrap();
assert_eq!(
folded,
ast::Expr {
location: loc,
custom: Some(TypeEnum::ParametricType(primitives::TUPLE_TYPE, vec![i32_t.clone().into(), float_t.clone().into(), i32_t.clone().into()]).into()),
node: ast::ExprKind::Tuple {
ctx: ast::ExprContext::Load,
elts: vec![
ast::Expr {
location: ast::Location::new(1, 2),
custom: Some(i32_t.clone()),
node: ast::ExprKind::Constant {
value: ast::Constant::Int(123.into()),
kind: None
}
},
ast::Expr {
location: ast::Location::new(1, 7),
custom: Some(float_t.clone()),
node: ast::ExprKind::Constant {
value: ast::Constant::Float(123.123),
kind: None
}
},
ast::Expr {
location: ast::Location::new(1, 16),
custom: Some(i32_t.clone()),
node: ast::ExprKind::Constant {
value: ast::Constant::Int(999999999.into()),
kind: None
}
},
]
}
}
);
}
#[test]
fn test_list() {
let mut inferencer = new_ctx();
let location = ast::Location::new(0, 0);
let ast: Expr<Option<Type>> = Expr {
let ast: Expr = Expr {
location,
custom: None,
custom: (),
node: ast::ExprKind::List {
ctx: ast::ExprContext::Load,
elts: vec![
Expr {
location,
custom: None,
custom: (),
node: ast::ExprKind::Constant {
value: ast::Constant::Int(1.into()),
kind: None,
@ -507,7 +575,7 @@ pub mod test {
Expr {
location,
custom: None,
custom: (),
node: ast::ExprKind::Constant {
value: ast::Constant::Int(2.into()),
kind: None,
@ -522,13 +590,13 @@ pub mod test {
new_ast,
Expr {
location,
custom: Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![inferencer.get_primitive(primitives::INT32_TYPE).into()]).into()),
custom: Some(TypeEnum::ParametricType(primitives::LIST_TYPE, vec![inferencer.ctx.get_primitive(primitives::INT32_TYPE).into()]).into()),
node: ast::ExprKind::List {
ctx: ast::ExprContext::Load,
elts: vec![
Expr {
location,
custom: Some(inferencer.get_primitive(primitives::INT32_TYPE)),
custom: Some(inferencer.ctx.get_primitive(primitives::INT32_TYPE)),
node: ast::ExprKind::Constant {
value: ast::Constant::Int(1.into()),
kind: None,
@ -537,7 +605,7 @@ pub mod test {
Expr {
location,
custom: Some(inferencer.get_primitive(primitives::INT32_TYPE)),
custom: Some(inferencer.ctx.get_primitive(primitives::INT32_TYPE)),
// custom: None,
node: ast::ExprKind::Constant {
value: ast::Constant::Int(2.into()),
@ -549,4 +617,42 @@ pub mod test {
}
);
}
#[test]
fn test_mix() {
let mut inf = new_ctx();
let ast1 = rustpython_parser::parser::parse_expression("False == [True or True, False][0]").unwrap();
let ast2 = rustpython_parser::parser::parse_expression("False == [True or True, False][0]").unwrap();
let ast3 = rustpython_parser::parser::parse_expression("1 < 2 < 3").unwrap();
let ast4 = rustpython_parser::parser::parse_expression("1 + [12312, 1231][0]").unwrap();
let ast5 = rustpython_parser::parser::parse_expression("not True").unwrap();
let ast6 = rustpython_parser::parser::parse_expression("[[1]][0][0]").unwrap();
let ast7 = rustpython_parser::parser::parse_expression("[[1]][0]").unwrap();
let ast8 = rustpython_parser::parser::parse_expression("[[(1, 2), (2, 3), (3, 4)], [(2, 4), (4, 6)]][0]").unwrap();
let ast9 = rustpython_parser::parser::parse_expression("[1, 2, 3, 4, 5][1: 2]").unwrap();
let ast10 = rustpython_parser::parser::parse_expression("4 if False and True else 8").unwrap();
let folded = inf.fold_expr(ast1).unwrap();
let folded_2 = Premapper.fold_expr(ast2).unwrap();
let folded_3 = inf.fold_expr(ast3).unwrap();
let folded_4 = inf.fold_expr(ast4).unwrap();
let folded_5 = inf.fold_expr(ast5).unwrap();
let folded_6 = inf.fold_expr(ast6).unwrap();
let folded_7 = inf.fold_expr(ast7).unwrap();
let folded_8 = inf.fold_expr(ast8).unwrap();
let folded_9 = inf.fold_expr(ast9).unwrap();
let folded_10 = inf.fold_expr(ast10).unwrap();
println!("{:?}", folded.custom);
println!("{:?}", folded_2.custom);
println!("{:?}", folded_3.custom);
println!("{:?}", folded_4.custom);
println!("{:?}", folded_5.custom);
println!("{:?}", folded_6.custom);
println!("{:?}", folded_7.custom);
println!("{:?}", folded_8.custom);
println!("{:?}", folded_9.custom);
println!("{:?}", folded_10.custom);
}
}