forked from M-Labs/nac3
implement simple expressions and return statement
This commit is contained in:
parent
69ed23ee00
commit
befd01b240
80
src/main.rs
80
src/main.rs
|
@ -10,16 +10,19 @@ use inkwell::module::Module;
|
|||
use inkwell::targets::*;
|
||||
use inkwell::types;
|
||||
use inkwell::types::BasicType;
|
||||
use inkwell::values;
|
||||
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::path::Path;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
enum CompileErrorKind {
|
||||
Unsupported(&'static str),
|
||||
MissingTypeAnnotation,
|
||||
UnknownTypeAnnotation,
|
||||
IncompatibleTypes,
|
||||
Internal(&'static str)
|
||||
}
|
||||
|
||||
|
@ -32,6 +35,8 @@ impl fmt::Display for CompileErrorKind {
|
|||
=> write!(f, "Missing type annotation"),
|
||||
CompileErrorKind::UnknownTypeAnnotation
|
||||
=> write!(f, "Unknown type annotation"),
|
||||
CompileErrorKind::IncompatibleTypes
|
||||
=> write!(f, "Incompatible types"),
|
||||
CompileErrorKind::Internal(details)
|
||||
=> write!(f, "Internal compiler error: {}", details),
|
||||
}
|
||||
|
@ -59,6 +64,7 @@ struct CodeGen<'ctx> {
|
|||
module: Module<'ctx>,
|
||||
builder: Builder<'ctx>,
|
||||
current_source_location: ast::Location,
|
||||
namespace: HashMap<String, values::BasicValueEnum<'ctx>>,
|
||||
}
|
||||
|
||||
impl<'ctx> CodeGen<'ctx> {
|
||||
|
@ -68,6 +74,7 @@ impl<'ctx> CodeGen<'ctx> {
|
|||
module: context.create_module("kernel"),
|
||||
builder: context.create_builder(),
|
||||
current_source_location: ast::Location::default(),
|
||||
namespace: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -147,21 +154,82 @@ impl<'ctx> CodeGen<'ctx> {
|
|||
let function = self.module.add_function(name, fn_type, None);
|
||||
let basic_block = self.context.append_basic_block(function, "entry");
|
||||
self.builder.position_at_end(basic_block);
|
||||
let x = function.get_nth_param(0).unwrap().into_int_value();
|
||||
let y = function.get_nth_param(1).unwrap().into_int_value();
|
||||
let sum = self.builder.build_int_add(x, y, "sum");
|
||||
self.builder.build_return(Some(&sum));
|
||||
|
||||
for (n, arg) in args.args.iter().enumerate() {
|
||||
self.namespace.insert(arg.arg.clone(), function.get_nth_param(n as u32).unwrap());
|
||||
}
|
||||
|
||||
for statement in body.iter() {
|
||||
self.compile_statement(statement)?;
|
||||
self.compile_statement(statement, return_type)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compile_statement(&mut self, statement: &ast::Statement) -> CompileResult<()> {
|
||||
fn compile_expression(
|
||||
&mut self,
|
||||
expression: &ast::Expression
|
||||
) -> CompileResult<values::BasicValueEnum<'ctx>> {
|
||||
self.set_source_location(expression.location);
|
||||
|
||||
use ast::ExpressionType::*;
|
||||
match &expression.node {
|
||||
Identifier { name } => {
|
||||
Ok(*self.namespace.get(name).unwrap())
|
||||
},
|
||||
Binop { a, op, b } => {
|
||||
let a = self.compile_expression(&a)?;
|
||||
let b = self.compile_expression(&b)?;
|
||||
if a.get_type() != b.get_type() {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
use ast::Operator::*;
|
||||
match (op, a, b) {
|
||||
(Add, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
||||
=> Ok(self.builder.build_int_add(a, b, "tmpadd").into()),
|
||||
(Sub, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
||||
=> Ok(self.builder.build_int_sub(a, b, "tmpsub").into()),
|
||||
(Mult, values::BasicValueEnum::IntValue(a), values::BasicValueEnum::IntValue(b))
|
||||
=> Ok(self.builder.build_int_mul(a, b, "tmpmul").into()),
|
||||
|
||||
(Add, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
||||
=> Ok(self.builder.build_float_add(a, b, "tmpadd").into()),
|
||||
(Sub, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
||||
=> Ok(self.builder.build_float_sub(a, b, "tmpsub").into()),
|
||||
(Mult, values::BasicValueEnum::FloatValue(a), values::BasicValueEnum::FloatValue(b))
|
||||
=> Ok(self.builder.build_float_mul(a, b, "tmpmul").into()),
|
||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented operation"))),
|
||||
}
|
||||
}
|
||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("unimplemented expression"))),
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_statement(
|
||||
&mut self,
|
||||
statement: &ast::Statement,
|
||||
return_type: Option<types::BasicTypeEnum>
|
||||
) -> CompileResult<()> {
|
||||
self.set_source_location(statement.location);
|
||||
|
||||
use ast::StatementType::*;
|
||||
match &statement.node {
|
||||
Return { value: Some(value) } => {
|
||||
if let Some(return_type) = return_type {
|
||||
let value = self.compile_expression(value)?;
|
||||
if value.get_type() != return_type {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
self.builder.build_return(Some(&value));
|
||||
} else {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
},
|
||||
Return { value: None } => {
|
||||
if !return_type.is_none() {
|
||||
return Err(self.compile_error(CompileErrorKind::IncompatibleTypes));
|
||||
}
|
||||
self.builder.build_return(None);
|
||||
},
|
||||
Pass => (),
|
||||
_ => return Err(self.compile_error(CompileErrorKind::Unsupported("special statement"))),
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue