#! /usr/bin/env python
"""Generate Rust code from an ASDL description."""

import os
import sys
import textwrap

import json

from argparse import ArgumentParser
from pathlib import Path

import asdl

TABSIZE = 4
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"

# Rust spelling of every ASDL builtin type.
builtin_type_mapping = {
    'identifier': 'Ident',
    'string': 'String',
    'int': 'usize',
    'constant': 'Constant',
    'bool': 'bool',
    'conversion_flag': 'ConversionFlag',
}
# Fail at import time if asdl gains or loses a builtin this table doesn't cover.
assert builtin_type_mapping.keys() == asdl.builtin_types


def get_rust_type(name):
    """Return a string for the Rust name of the type.

    Builtin ASDL types are looked up in `builtin_type_mapping`; any other
    snake_case name is converted to CamelCase.
    """
    if name in asdl.builtin_types:
        return builtin_type_mapping[name]
    return "".join(part.capitalize() for part in name.split("_"))


def is_simple(sum):
    """Return True if a sum is a simple.

    A sum is simple if its types have no fields, e.g.
    unaryop = Invert | Not | UAdd | USub
    """
    return not any(t.fields for t in sum.types)


def asdl_of(name, obj):
    """Render `obj` (a product, constructor or sum) back into ASDL-like text."""
    if isinstance(obj, (asdl.Product, asdl.Constructor)):
        fields = ", ".join(map(str, obj.fields))
        if fields:
            fields = "({})".format(fields)
        return "{}{}".format(name, fields)

    if is_simple(obj):
        types = " | ".join(type.name for type in obj.types)
    else:
        # Line up each following alternative under the `=` sign.
        sep = "\n{}| ".format(" " * (len(name) + 1))
        types = sep.join(asdl_of(type.name, type) for type in obj.types)
    return "{} = {}".format(name, types)


class EmitVisitor(asdl.VisitorBase):
    """Visitor that emits lines to a file object."""

    def __init__(self, file):
        self.file = file
        self.identifiers = set()
        super().__init__()

    def emit_identifier(self, name):
        """Emit a `_Py_IDENTIFIER(name);` line the first time `name` is seen."""
        name = str(name)
        if name in self.identifiers:
            return
        self.emit("_Py_IDENTIFIER(%s);" % name, 0)
        self.identifiers.add(name)

    def emit(self, line, depth):
        """Write `line` indented by `depth` levels; an empty line is written bare."""
        if line:
            line = (" " * TABSIZE * depth) + line
        self.file.write(line + "\n")


class TypeInfo:
    """Per-type bookkeeping gathered before code generation."""

    def __init__(self, name):
        self.name = name
        # True / False once known, None while still undetermined.
        self.has_userdata = None
        # Set of (field type name, field.seq) pairs.
        self.children = set()
        self.boxed = False

    def __repr__(self):
        return f"<TypeInfo: {self.name}>"

    def determine_userdata(self, typeinfo, stack):
        """Propagate `has_userdata` upwards from this type's children.

        `typeinfo` maps type names to TypeInfo objects; `stack` is the set of
        names currently being resolved and is used to break reference cycles.
        Returns the (possibly still None) `has_userdata` value, or None when
        this type is already on the stack.
        """
        if self.name in stack:
            return None
        stack.add(self.name)
        for child, _ in self.children:
            if child in asdl.builtin_types:
                continue
            childinfo = typeinfo[child]
            child_has_userdata = childinfo.determine_userdata(typeinfo, stack)
            # Only an undetermined type may be upgraded; an explicit False stays.
            if self.has_userdata is None and child_has_userdata is True:
                self.has_userdata = True
        stack.remove(self.name)
        return self.has_userdata


class FindUserdataTypesVisitor(asdl.VisitorBase):
    """Fill a `typeinfo` dict with a TypeInfo for every type in the module."""

    def __init__(self, typeinfo):
        self.typeinfo = typeinfo
        super().__init__()

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)
        # Second pass: all types are registered, so userdata can be propagated.
        stack = set()
        for info in self.typeinfo.values():
            info.determine_userdata(self.typeinfo, stack)

    def visitType(self, type):
        self.typeinfo[type.name] = TypeInfo(type.name)
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        info = self.typeinfo[name]
        if is_simple(sum):
            info.has_userdata = False
        else:
            if len(sum.types) > 1:
                info.boxed = True
            if sum.attributes:
                # attributes means Located, which has the `custom: U` field
                info.has_userdata = True
        for variant in sum.types:
            self.add_children(name, variant.fields)

    def visitProduct(self, product, name):
        info = self.typeinfo[name]
        if product.attributes:
            # attributes means Located, which has the `custom: U` field
            info.has_userdata = True
        if len(product.fields) > 2:
            info.boxed = True
        self.add_children(name, product.fields)

    def add_children(self, name, fields):
        """Record the type (and seq flag) of each field as a child of `name`."""
        self.typeinfo[name].children.update(
            (field.type, field.seq) for field in fields
        )


def rust_field(field_name):
    """Return `field_name` as a Rust field identifier (`type` becomes `type_`)."""
    if field_name == 'type':
        return 'type_'
    return field_name


class TypeInfoEmitVisitor(EmitVisitor):
    """EmitVisitor that also carries the `typeinfo` table."""

    def __init__(self, file, typeinfo):
        self.typeinfo = typeinfo
        super().__init__(file)

    def has_userdata(self, typ):
        return self.typeinfo[typ].has_userdata

    def get_generics(self, typ, *generics):
        """Return each of `generics` wrapped in `<...>`, or "" for each one
        when `typ` carries no user data."""
        if self.has_userdata(typ):
            return [f"<{g}>" for g in generics]
        return ["" for _ in generics]


class StructVisitor(TypeInfoEmitVisitor):
    """Visitor to generate typedefs for AST."""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        if is_simple(sum):
            self.simple_sum(sum, name, depth)
        else:
            self.sum_with_constructors(sum, name, depth)

    def emit_attrs(self, depth):
        self.emit("#[derive(Clone, Debug, PartialEq)]", depth)

    def simple_sum(self, sum, name, depth):
        """Emit a field-less Rust enum with one variant per constructor."""
        rustname = get_rust_type(name)
        self.emit_attrs(depth)
        self.emit(f"pub enum {rustname} {{", depth)
        for variant in sum.types:
            self.emit(f"{variant.name},", depth + 1)
        self.emit("}", depth)
        self.emit("", depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit a Rust enum whose variants carry fields."""
        typeinfo = self.typeinfo[name]
        generics, generics_applied = self.get_generics(name, "U = ()", "U")
        enumname = rustname = get_rust_type(name)
        # all the attributes right now are for location, so if it has attrs we
        # can just wrap it in Located<>
        if sum.attributes:
            enumname = rustname + "Kind"
        self.emit_attrs(depth)
        self.emit(f"pub enum {enumname}{generics} {{", depth)
        for t in sum.types:
            self.visit(t, typeinfo, depth + 1)
        self.emit("}", depth)
        if sum.attributes:
            # The alias must declare `U` itself, otherwise the right-hand side
            # refers to an unknown type parameter.
            self.emit(f"pub type {rustname}<U = ()> = Located<{enumname}{generics_applied}, U>;", depth)
        self.emit("", depth)

    def visitConstructor(self, cons, parent, depth):
        if cons.fields:
            self.emit(f"{cons.name} {{", depth)
            for f in cons.fields:
                self.visit(f, parent, "", depth + 1)
            self.emit("},", depth)
        else:
            self.emit(f"{cons.name},", depth)

    def visitField(self, field, parent, vis, depth):
        """Emit one `name: Type,` line, wrapping the type as the field requires."""
        typ = get_rust_type(field.type)
        fieldtype = self.typeinfo.get(field.type)
        if fieldtype and fieldtype.has_userdata:
            typ = f"{typ}<U>"
        # don't box if we're doing Vec<T>, but do box if we're doing Vec<Option<Box<T>>>
        if fieldtype and fieldtype.boxed and (not field.seq or field.opt):
            typ = f"Box<{typ}>"
        if field.opt:
            typ = f"Option<{typ}>"
        if field.seq:
            typ = f"Vec<{typ}>"
        name = rust_field(field.name)
        self.emit(f"{vis}{name}: {typ},", depth)

    def visitProduct(self, product, name, depth):
        """Emit a Rust struct for a product type."""
        typeinfo = self.typeinfo[name]
        generics, generics_applied = self.get_generics(name, "U = ()", "U")
        dataname = rustname = get_rust_type(name)
        if product.attributes:
            dataname = rustname + "Data"
        self.emit_attrs(depth)
        self.emit(f"pub struct {dataname}{generics} {{", depth)
        for f in product.fields:
            self.visit(f, typeinfo, "pub ", depth + 1)
        self.emit("}", depth)
        if product.attributes:
            # attributes should just be location info
            self.emit(f"pub type {rustname}<U = ()> = Located<{dataname}{generics_applied}, U>;", depth)
        self.emit("", depth)


class FoldTraitDefVisitor(TypeInfoEmitVisitor):
    """Emit the `Fold<U>` trait with one `fold_<type>` method per definition."""

    def visitModule(self, mod, depth):
        self.emit("pub trait Fold<U> {", depth)
        self.emit("type TargetU;", depth + 1)
        self.emit("type Error;", depth + 1)
        self.emit("fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;", depth + 2)
        for dfn in mod.dfns:
            self.visit(dfn, depth + 2)
        self.emit("}", depth)

    def visitType(self, type, depth):
        name = type.name
        apply_u, apply_target_u = self.get_generics(name, "U", "Self::TargetU")
        enumname = get_rust_type(name)
        self.emit(f"fn fold_{name}(&mut self, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, Self::Error> {{", depth)
        self.emit(f"fold_{name}(self, node)", depth + 1)
        self.emit("}", depth)


class FoldImplVisitor(TypeInfoEmitVisitor):
    """Emit `Foldable` impls and free `fold_<type>` functions."""

    def visitModule(self, mod, depth):
        self.emit("fn fold_located<U, F: Fold<U> + ?Sized, T, MT>(folder: &mut F, node: Located<T, U>, f: impl FnOnce(&mut F, T) -> Result<MT, F::Error>) -> Result<Located<MT, F::TargetU>, F::Error> {", depth)
        self.emit("Ok(Located { custom: folder.map_user(node.custom)?, location: node.location, node: f(folder, node.node)? })", depth + 1)
        self.emit("}", depth)
        for dfn in mod.dfns:
            self.visit(dfn, depth)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        apply_t, apply_u, apply_target_u = self.get_generics(name, "T", "U", "F::TargetU")
        enumname = get_rust_type(name)
        is_located = bool(sum.attributes)

        self.emit(f"impl<T, U> Foldable<T, U> for {enumname}{apply_t} {{", depth)
        self.emit(f"type Mapped = {enumname}{apply_u};", depth + 1)
        self.emit("fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {", depth + 1)
        self.emit(f"folder.fold_{name}(self)", depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

        self.emit(f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {enumname}{apply_u}) -> Result<{enumname}{apply_target_u}, F::Error> {{", depth)
        if is_located:
            # From here on the match is over the inner `...Kind` enum.
            self.emit("fold_located(folder, node, |folder, node| {", depth)
            enumname += "Kind"
        self.emit("match node {", depth + 1)
        for cons in sum.types:
            fields_pattern = self.make_pattern(cons.fields)
            self.emit(f"{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth + 2)
            self.gen_construction(f"{enumname}::{cons.name}", cons.fields, depth + 3)
            self.emit("}", depth + 2)
        self.emit("}", depth + 1)
        if is_located:
            self.emit("})", depth)
        self.emit("}", depth)

    def visitProduct(self, product, name, depth):
        apply_t, apply_u, apply_target_u = self.get_generics(name, "T", "U", "F::TargetU")
        structname = get_rust_type(name)
        is_located = bool(product.attributes)

        self.emit(f"impl<T, U> Foldable<T, U> for {structname}{apply_t} {{", depth)
        self.emit(f"type Mapped = {structname}{apply_u};", depth + 1)
        self.emit("fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {", depth + 1)
        self.emit(f"folder.fold_{name}(self)", depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

        self.emit(f"pub fn fold_{name}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {structname}{apply_u}) -> Result<{structname}{apply_target_u}, F::Error> {{", depth)
        if is_located:
            # From here on the destructuring is of the inner `...Data` struct.
            self.emit("fold_located(folder, node, |folder, node| {", depth)
            structname += "Data"
        fields_pattern = self.make_pattern(product.fields)
        self.emit(f"let {structname} {{ {fields_pattern} }} = node;", depth + 1)
        self.gen_construction(structname, product.fields, depth + 1)
        if is_located:
            self.emit("})", depth)
        self.emit("}", depth)

    def make_pattern(self, fields):
        """Return the comma-joined Rust field names used in a destructuring pattern."""
        return ",".join(rust_field(f.name) for f in fields)

    def gen_construction(self, cons_path, fields, depth):
        """Emit `Ok(cons_path { field: Foldable::fold(field, folder)?, ... })`."""
        self.emit(f"Ok({cons_path} {{", depth)
        for field in fields:
            name = rust_field(field.name)
            self.emit(f"{name}: Foldable::fold({name}, folder)?,", depth + 1)
        self.emit("})", depth)


class FoldModuleVisitor(TypeInfoEmitVisitor):
    """Emit the feature-gated `fold` module (trait definition plus impls)."""

    def visitModule(self, mod):
        depth = 0
        self.emit('#[cfg(feature = "fold")]', depth)
        self.emit("pub mod fold {", depth)
        self.emit("use super::*;", depth + 1)
        self.emit("use crate::fold_helpers::Foldable;", depth + 1)
        FoldTraitDefVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
        FoldImplVisitor(self.file, self.typeinfo).visit(mod, depth + 1)
        self.emit("}", depth)


class ClassDefVisitor(EmitVisitor):
    """Emit a `#[pyclass]` unit struct `Node<Name>` per constructor / product."""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        for cons in sum.types:
            self.visit(cons, sum.attributes, depth)

    def visitConstructor(self, cons, attrs, depth):
        self.gen_classdef(cons.name, cons.fields, attrs, depth)

    def visitProduct(self, product, name, depth):
        self.gen_classdef(name, product.fields, product.attributes, depth)

    def gen_classdef(self, name, fields, attrs, depth):
        """Emit the struct and an impl that sets `_fields` and `_attributes`."""
        structname = "Node" + name
        self.emit(f'#[pyclass(module = "_ast", name = {json.dumps(name)}, base = "AstNode")]', depth)
        self.emit(f"struct {structname};", depth)
        self.emit("#[pyimpl(flags(HAS_DICT, BASETYPE))]", depth)
        self.emit(f"impl {structname} {{", depth)
        self.emit("#[extend_class]", depth + 1)
        self.emit("fn extend_class_with_fields(ctx: &PyContext, class: &PyTypeRef) {", depth + 1)
        fields = ",".join(f"ctx.new_str({json.dumps(f.name)})" for f in fields)
        self.emit(f'class.set_str_attr("_fields", ctx.new_list(vec![{fields}]));', depth + 2)
        attrs = ",".join(f"ctx.new_str({json.dumps(attr.name)})" for attr in attrs)
        self.emit(f'class.set_str_attr("_attributes", ctx.new_list(vec![{attrs}]));', depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)


class ExtendModuleVisitor(EmitVisitor):
    """Emit `extend_module_nodes`, which registers every `Node<Name>` class."""

    def visitModule(self, mod):
        depth = 0
        self.emit("pub fn extend_module_nodes(vm: &VirtualMachine, module: &PyObjectRef) {", depth)
        self.emit("extend_module!(vm, module, {", depth + 1)
        for dfn in mod.dfns:
            self.visit(dfn, depth + 2)
        self.emit("})", depth + 1)
        self.emit("}", depth)

    def visitType(self, type, depth):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        for cons in sum.types:
            self.visit(cons, depth)

    def visitConstructor(self, cons, depth):
        self.gen_extension(cons.name, depth)

    def visitProduct(self, product, name, depth):
        self.gen_extension(name, depth)

    def gen_extension(self, name, depth):
        self.emit(f"{json.dumps(name)} => Node{name}::make_class(&vm.ctx),", depth)


class TraitImplVisitor(EmitVisitor):
    """Emit `NamedNode` and `Node` impls for every sum and product type."""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        enumname = get_rust_type(name)
        if sum.attributes:
            enumname += "Kind"

        self.emit(f"impl NamedNode for ast::{enumname} {{", depth)
        self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
        self.emit("}", depth)
        self.emit(f"impl Node for ast::{enumname} {{", depth)
        self.emit("fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1)
        self.emit("match self {", depth + 2)
        for variant in sum.types:
            self.constructor_to_object(variant, enumname, depth + 3)
        self.emit("}", depth + 2)
        self.emit("}", depth + 1)
        self.emit("fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {", depth + 1)
        self.gen_sum_fromobj(sum, name, enumname, depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

    def constructor_to_object(self, cons, enumname, depth):
        """Emit one `match` arm converting a variant into its Python node."""
        fields_pattern = self.make_pattern(cons.fields)
        self.emit(f"ast::{enumname}::{cons.name} {{ {fields_pattern} }} => {{", depth)
        self.make_node(cons.name, cons.fields, depth + 1)
        self.emit("}", depth)

    def visitProduct(self, product, name, depth):
        structname = get_rust_type(name)
        if product.attributes:
            structname += "Data"

        self.emit(f"impl NamedNode for ast::{structname} {{", depth)
        self.emit(f"const NAME: &'static str = {json.dumps(name)};", depth + 1)
        self.emit("}", depth)
        self.emit(f"impl Node for ast::{structname} {{", depth)
        self.emit("fn ast_to_object(self, _vm: &VirtualMachine) -> PyObjectRef {", depth + 1)
        fields_pattern = self.make_pattern(product.fields)
        self.emit(f"let ast::{structname} {{ {fields_pattern} }} = self;", depth + 2)
        self.make_node(name, product.fields, depth + 2)
        self.emit("}", depth + 1)
        self.emit("fn ast_from_object(_vm: &VirtualMachine, _object: PyObjectRef) -> PyResult<Self> {", depth + 1)
        self.gen_product_fromobj(product, name, structname, depth + 2)
        self.emit("}", depth + 1)
        self.emit("}", depth)

    def make_node(self, variant, fields, depth):
        """Emit code that builds a `Node<variant>` object and fills its dict."""
        self.emit(f"let _node = AstNode.into_ref_with_type(_vm, Node{variant}::static_type().clone()).unwrap();", depth)
        if fields:
            self.emit("let _dict = _node.as_object().dict().unwrap();", depth)
        for f in fields:
            self.emit(f"_dict.set_item({json.dumps(f.name)}, {rust_field(f.name)}.ast_to_object(_vm), _vm).unwrap();", depth)
        self.emit("_node.into_object()", depth)

    def make_pattern(self, fields):
        """Return the comma-joined Rust field names used in a destructuring pattern."""
        return ",".join(rust_field(f.name) for f in fields)

    def gen_sum_fromobj(self, sum, sumname, enumname, depth):
        """Emit an if/else-if chain that picks the variant from the object's class."""
        if sum.attributes:
            self.extract_location(sumname, depth)

        self.emit("let _cls = _object.class();", depth)
        self.emit("Ok(", depth)
        for cons in sum.types:
            self.emit(f"if _cls.is(Node{cons.name}::static_type()) {{", depth)
            self.gen_construction(f"{enumname}::{cons.name}", cons, sumname, depth + 1)
            self.emit("} else", depth)

        # Final `else` branch: no class matched.
        self.emit("{", depth)
        msg = f'format!("expected some sort of {sumname}, but got {{}}",_vm.to_repr(&_object)?)'
        self.emit(f"return Err(_vm.new_type_error({msg}));", depth + 1)
        self.emit("})", depth)

    def gen_product_fromobj(self, product, prodname, structname, depth):
        if product.attributes:
            self.extract_location(prodname, depth)

        self.emit("Ok(", depth)
        self.gen_construction(structname, product, prodname, depth + 1)
        self.emit(")", depth)

    def gen_construction(self, cons_path, cons, name, depth):
        """Emit `ast::cons_path { field: <decoded>, ... }`."""
        self.emit(f"ast::{cons_path} {{", depth)
        for field in cons.fields:
            self.emit(f"{rust_field(field.name)}: {self.decode_field(field, name)},", depth + 1)
        self.emit("}", depth)

    def extract_location(self, typename, depth):
        """Emit a `_location` binding built from `lineno` and `col_offset`."""
        row = self.decode_field(asdl.Field('int', 'lineno'), typename)
        column = self.decode_field(asdl.Field('int', 'col_offset'), typename)
        self.emit(f"let _location = ast::Location::new({row}, {column});", depth)

    def wrap_located_node(self, depth):
        self.emit("let node = ast::Located::new(_location, node);", depth)

    def decode_field(self, field, typename):
        """Return the Rust expression that reads `field` from `_object`."""
        name = json.dumps(field.name)
        if field.opt and not field.seq:
            return f"get_node_field_opt(_vm, &_object, {name})?.map(|obj| Node::ast_from_object(_vm, obj)).transpose()?"
        return f"Node::ast_from_object(_vm, get_node_field(_vm, &_object, {name}, {json.dumps(typename)})?)?"
class ChainOfVisitors:
    """Run several visitors over the same tree, one after another."""

    def __init__(self, *visitors):
        self.visitors = visitors

    def visit(self, object):
        """Visit `object` with each visitor in order.

        After every visitor an empty line is emitted through that visitor.
        """
        for v in self.visitors:
            v.visit(object)
            v.emit("", 0)


# Fixed text written at the top of the AST definition file.
AST_DEF_PRELUDE = (
    'pub use crate::location::Location;\n'
    'pub use crate::constant::*;\n'
    '\n'
    'type Ident = String;\n'
    '\n'
)

# Definition of the `Located` wrapper. The struct must declare both type
# parameters it uses (`T` for the node, `U` for the custom data, defaulting to
# `()`), and the impl must introduce `T`; without them the Rust does not compile.
LOCATED_DEFINITION = (
    'pub struct Located<T, U = ()> {\n'
    ' pub location: Location,\n'
    ' pub custom: U,\n'
    ' pub node: T,\n'
    '}\n'
    '\n'
    'impl<T> Located<T> {\n'
    ' pub fn new(location: Location, node: T) -> Self {\n'
    ' Self { location, custom: (), node }\n'
    ' }\n'
    '}\n'
    '\n'
)


def write_ast_def(mod, typeinfo, f):
    """Write the Rust AST definitions for `mod` to the file object `f`.

    Output order: the fixed prelude, the derive attributes followed by the
    `Located` struct and its constructor, then the output of StructVisitor
    and FoldModuleVisitor.
    """
    f.write(AST_DEF_PRELUDE)
    # Reuse the same derive line that the generated structs and enums get.
    StructVisitor(f, typeinfo).emit_attrs(0)
    f.write(LOCATED_DEFINITION)

    c = ChainOfVisitors(StructVisitor(f, typeinfo),
                        FoldModuleVisitor(f, typeinfo))
    c.visit(mod)


def write_ast_mod(mod, f):
    """Write the Python-binding module for `mod` to the file object `f`."""
    f.write('use super::*;\n')
    f.write('\n')

    c = ChainOfVisitors(ClassDefVisitor(f),
                        TraitImplVisitor(f),
                        ExtendModuleVisitor(f))
    c.visit(mod)


def main(input_filename, ast_mod_filename, ast_def_filename, dump_module=False):
    """Parse the ASDL file and regenerate both Rust output files.

    `ast_mod_filename` and `ast_def_filename` must support `.open("w")`
    (the command line passes `pathlib.Path` objects). Exits with status 1
    when `asdl.check` rejects the parsed module; no file is opened in that
    case.
    """
    auto_gen_msg = AUTOGEN_MESSAGE.format("/".join(Path(__file__).parts[-2:]))
    mod = asdl.parse(input_filename)
    if dump_module:
        print('Parsed Module:')
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)

    typeinfo = {}
    FindUserdataTypesVisitor(typeinfo).visit(mod)

    with ast_def_filename.open("w") as def_file, \
            ast_mod_filename.open("w") as mod_file:
        def_file.write(auto_gen_msg)
        write_ast_def(mod, typeinfo, def_file)

        mod_file.write(auto_gen_msg)
        write_ast_mod(mod, mod_file)

    print(f"{ast_def_filename}, {ast_mod_filename} regenerated.")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("input_file", type=Path)
    parser.add_argument("-M", "--mod-file", type=Path, required=True)
    parser.add_argument("-D", "--def-file", type=Path, required=True)
    parser.add_argument("-d", "--dump-module", action="store_true")

    args = parser.parse_args()
    main(args.input_file, args.mod_file, args.def_file, args.dump_module)