From befd01b240963057e2d41becf17edc9b4be8523e Mon Sep 17 00:00:00 2001 From: Sebastien Bourdeauducq Date: Sun, 29 Mar 2020 12:12:08 +0800 Subject: [PATCH] implement simple expressions and return statement --- src/main.rs | 80 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/src/main.rs b/src/main.rs index 8caf381..daa3233 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,16 +10,19 @@ use inkwell::module::Module; use inkwell::targets::*; use inkwell::types; use inkwell::types::BasicType; +use inkwell::values; use std::error::Error; use std::fmt; use std::path::Path; +use std::collections::HashMap; #[derive(Debug)] enum CompileErrorKind { Unsupported(&'static str), MissingTypeAnnotation, UnknownTypeAnnotation, + IncompatibleTypes, Internal(&'static str) } @@ -32,6 +35,8 @@ impl fmt::Display for CompileErrorKind { => write!(f, "Missing type annotation"), CompileErrorKind::UnknownTypeAnnotation => write!(f, "Unknown type annotation"), + CompileErrorKind::IncompatibleTypes + => write!(f, "Incompatible types"), CompileErrorKind::Internal(details) => write!(f, "Internal compiler error: {}", details), } @@ -59,6 +64,7 @@ struct CodeGen<'ctx> { module: Module<'ctx>, builder: Builder<'ctx>, current_source_location: ast::Location, + namespace: HashMap>, } impl<'ctx> CodeGen<'ctx> { @@ -68,6 +74,7 @@ impl<'ctx> CodeGen<'ctx> { module: context.create_module("kernel"), builder: context.create_builder(), current_source_location: ast::Location::default(), + namespace: HashMap::new(), } } @@ -147,21 +154,82 @@ impl<'ctx> CodeGen<'ctx> { let function = self.module.add_function(name, fn_type, None); let basic_block = self.context.append_basic_block(function, "entry"); self.builder.position_at_end(basic_block); - let x = function.get_nth_param(0).unwrap().into_int_value(); - let y = function.get_nth_param(1).unwrap().into_int_value(); - let sum = self.builder.build_int_add(x, y, "sum"); - self.builder.build_return(Some(&sum)); + + for (n, arg) in args.args.iter().enumerate() { + self.namespace.insert(arg.arg.clone(), function.get_nth_param(n as u32).unwrap()); + } + for statement in body.iter() { - self.compile_statement(statement)?; + self.compile_statement(statement, return_type)?; } Ok(()) } - fn compile_statement(&mut self, statement: &ast::Statement) -> CompileResult<()> { + fn compile_expression( + &mut self, + expression: &ast::Expression + ) -> CompileResult> { + self.set_source_location(expression.location); + + use ast::ExpressionType::*; + match &expression.node { + Identifier { name } => { + Ok(*self.namespace.get(name).unwrap()) + }, + Binop { a, op, b } => { + let a = self.compile_expression(&a)?; + let b = self.compile_expression(&b)?; + if a.get_type() != b.get_type() { + return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); + } + use ast::Operator::*; + match (op, a, b) { + (Add, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) + => Ok(self.builder.build_int_add(a, b, "tmpadd").into()), + (Sub, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) + => Ok(self.builder.build_int_sub(a, b, "tmpsub").into()), + (Mult, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b)) + => Ok(self.builder.build_int_mul(a, b, "tmpmul").into()), + + (Add, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) + => Ok(self.builder.build_float_add(a, b, "tmpadd").into()), + (Sub, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) + => Ok(self.builder.build_float_sub(a, b, "tmpsub").into()), + (Mult, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b)) + => Ok(self.builder.build_float_mul(a, b, "tmpmul").into()), + _ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented operation"))), + } + } + _ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented expression"))), + } + } + + fn compile_statement( + &mut self, + statement: &ast::Statement, + return_type: Option + ) -> CompileResult<()> { self.set_source_location(statement.location); use ast::StatementType::*; match &statement.node { + Return { value: Some(value) } => { + if let Some(return_type) = return_type { + let value = self.compile_expression(value)?; + if value.get_type() != return_type { + return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); + } + self.builder.build_return(Some(&value)); + } else { + return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); + } + }, + Return { value: None } => { + if !return_type.is_none() { + return Err(self.compile_error(CompileErrorKind::IncompatibleTypes)); + } + self.builder.build_return(None); + }, Pass => (), _ => return Err(self.compile_error(CompileErrorKind::Unsupported("special statement"))), }