Update Chapter 3 to demonstrate pattern match and rewrite optimizations

This is using Table-driven Declarative Rewrite Rules (DRR); the previous
version of the tutorial only showed the C++ patterns.

Closes tensorflow/mlir#187

PiperOrigin-RevId: 274852321
This commit is contained in:
Sana Damani
2019-10-15 11:40:12 -07:00
committed by A. Unique TensorFlower
parent 4e85dafedd
commit cd45b0c8d9
16 changed files with 1049 additions and 1189 deletions

View File

@@ -2,16 +2,33 @@ set(LLVM_LINK_COMPONENTS
Support
)
set(LLVM_TARGET_DEFINITIONS include/toy/Ops.td)
mlir_tablegen(include/toy/Ops.h.inc -gen-op-decls)
mlir_tablegen(include/toy/Ops.cpp.inc -gen-op-defs)
add_public_tablegen_target(ToyCh3OpsIncGen)
set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
add_public_tablegen_target(ToyCh3CombineIncGen)
add_toy_chapter(toyc-ch3
toyc.cpp
parser/AST.cpp
mlir/MLIRGen.cpp
mlir/ToyDialect.cpp
mlir/Dialect.cpp
mlir/ToyCombine.cpp
)
add_dependencies(toyc-ch3 ToyCh3OpsIncGen)
add_dependencies(toyc-ch3 ToyCh3CombineIncGen)
include_directories(include/)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
target_link_libraries(toyc-ch3
PRIVATE
MLIRAnalysis
MLIRIR
MLIRParser
MLIRPass
MLIRTransforms)

View File

@@ -33,10 +33,9 @@
namespace toy {
/// A variable
/// A variable type with shape information.
struct VarType {
enum { TY_FLOAT, TY_INT } elt_ty;
std::vector<int> shape;
std::vector<int64_t> shape;
};
/// Base class for all expression nodes.
@@ -50,9 +49,7 @@ public:
Expr_Var,
Expr_BinOp,
Expr_Call,
Expr_Print, // builtin
Expr_If,
Expr_For,
Expr_Print,
};
ExprAST(ExprASTKind kind, Location location)
@@ -85,7 +82,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
};
///
/// Expression class for a literal value.
class LiteralExprAST : public ExprAST {
std::vector<std::unique_ptr<ExprAST>> values;
std::vector<int64_t> dims;
@@ -116,7 +113,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
};
///
/// Expression class for defining a variable.
class VarDeclExprAST : public ExprAST {
std::string name;
VarType type;
@@ -136,7 +133,7 @@ public:
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
};
///
/// Expression class for a return operator.
class ReturnExprAST : public ExprAST {
llvm::Optional<std::unique_ptr<ExprAST>> expr;

View File

@@ -16,7 +16,7 @@
// =============================================================================
//
// This file implements the IR Dialect for the Toy language.
// See g3doc/Tutorials/Toy/Ch-3.md for more information.
// See g3doc/Tutorials/Toy/Ch-2.md for more information.
//
//===----------------------------------------------------------------------===//
@@ -25,311 +25,29 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"
namespace mlir {
class Builder;
}
namespace toy {
/// This is the definition of the Toy dialect. A dialect inherits from
/// mlir::Dialect and register custom operations and types (in its constructor).
/// It can also overriding general behavior of dialects exposed as virtual
/// method, for example regarding verification and parsing/printing.
/// mlir::Dialect and registers custom attributes, operations, and types (in its
/// constructor). It can also override some general behavior exposed via virtual
/// methods.
class ToyDialect : public mlir::Dialect {
public:
explicit ToyDialect(mlir::MLIRContext *ctx);
/// Parse a type registered to this dialect. Overriding this method is
/// required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
mlir::Type parseType(llvm::StringRef tyData,
mlir::Location loc) const override;
/// Print a type registered to this dialect. Overriding this method is
/// only required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
void printType(mlir::Type type, llvm::raw_ostream &os) const override;
/// Provide a utility accessor to the dialect namespace. This is used by
/// several utilities for casting between dialects.
static llvm::StringRef getDialectNamespace() { return "toy"; }
};
////////////////////////////////////////////////////////////////////////////////
/////////////////////// Custom Types for the Dialect ///////////////////////////
////////////////////////////////////////////////////////////////////////////////
namespace detail {
struct ToyArrayTypeStorage;
}
/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
enum ToyTypeKind {
// The enum starts at the range reserved for this dialect.
TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
TOY_ARRAY,
};
/// Type for Toy arrays.
/// In MLIR Types are reference to immutable and uniqued objects owned by the
/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued
/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
/// provides the public facade API to interact with the type.
class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
detail::ToyArrayTypeStorage> {
public:
using Base::Base;
/// Returns the dimensions for this array, or and empty range for a generic
/// array.
llvm::ArrayRef<int64_t> getShape();
/// Predicate to test if this array is generic (shape haven't been inferred
/// yet).
bool isGeneric() { return getShape().empty(); }
/// Return the rank of this array (0 if it is generic).
int getRank() { return getShape().size(); }
/// Return the type of individual elements in the array.
mlir::Type getElementType();
/// Get the unique instance of this Type from the context.
/// A ToyArrayType is only defined by the shape of the array.
static ToyArrayType get(mlir::MLIRContext *context,
llvm::ArrayRef<int64_t> shape = {});
/// Support method to enable LLVM-style RTTI type casting.
static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
};
////////////////////////////////////////////////////////////////////////////////
//////////////////// Custom Operations for the Dialect /////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Constant operation turns a literal into an SSA value. The data is attached
/// to the operation as an attribute. For example:
///
/// %0 = "toy.constant"()
/// {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
/// : () -> !toy.array<2, 3>
///
/// An operation inherits from `class Op` and specifies optional traits. Here we
/// indicate that `toy.constant` does not have any operands and returns a single
/// result. The traits provide some utilities methods for the operation, for
/// instance we will be able to use `getResult()`, but `getOperand()` won't be
/// available.
class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
public:
/// This is the name used by MLIR to match an operation to this class during
/// parsing.
static llvm::StringRef getOperationName() { return "toy.constant"; }
/// The operation can have extra verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populates the `state` that MLIR uses to create operations.
/// The `toy.constant` operation does not have arguments but attaches a
/// constant array as an attribute and returns it as an SSA value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
llvm::ArrayRef<int64_t> shape,
mlir::DenseElementsAttr value);
/// Similar to the one above, but takes a single float and returns a
/// !toy.array<1>.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::FloatAttr value);
/// Inherit constructor.
using Op::Op;
};
/// Generic calls represent calls to a user defined function that needs to
/// be specialized for the shape of its arguments. The callee name is attached
/// as a literal string as an attribute. The arguments list must match the
/// arguments expected by the callee. For example:
///
/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
/// : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
///
/// This is only valid if a function named "my_func" exists and takes two
/// arguments.
class GenericCallOp
: public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::OneResult> {
public:
/// MLIR will use this to register the operation with the parser/printer.
static llvm::StringRef getOperationName() { return "toy.generic_call"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to the builder to allow:
/// mlir::Builder::create<GenericCallOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.generic_call` operation accepts a callee name and a list of
/// arguments for the call.
static void build(mlir::Builder *builder, mlir::OperationState &state,
llvm::StringRef callee,
llvm::ArrayRef<mlir::Value *> arguments);
/// Return the name of the callee.
llvm::StringRef getCalleeName();
/// Inherit constructor.
using Op::Op;
};
/// Return operations terminate blocks (and functions as well). They take a
/// single argument and the type must match the function return type.
class ReturnOp
: public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
public:
static llvm::StringRef getOperationName() { return "toy.return"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.return` operation accepts an optional single array as an argument
/// and does not have any returned value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value = nullptr);
/// Return true if there is a returned value.
bool hasOperand() { return 0 != getNumOperands(); }
/// Helper to return the optional operand. Caller must check if the operand
/// is present before calling this.
mlir::Value *getOperand() { return getOperation()->getOperand(0); }
/// Inherit constructor.
using Op::Op;
};
/// The print builtin takes a single array argument and does not return any.
class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::ZeroResult> {
public:
static llvm::StringRef getOperationName() { return "toy.print"; }
/// Operations can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.print` operation accepts a single array as argument and does
/// not have any returned value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value);
/// Inherit constructor.
using Op::Op;
};
class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::OneResult> {
public:
static llvm::StringRef getOperationName() { return "toy.transpose"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<TransposeOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.transpose` operation accepts a single array as argument and
/// returns the transposed array as its only result.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value);
/// Inherit constructor.
using Op::Op;
};
/// Reshape operation is transforming its input array into a new array with the
/// same number of elements but different shapes. For example:
///
/// %0 = "toy.reshape"(%arg1) : (!toy.array<10>) -> !toy.array<5, 2>
///
class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
mlir::OpTrait::OneResult> {
public:
static llvm::StringRef getOperationName() { return "toy.reshape"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<ReshapeOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.reshape` operation accepts a single array as argument and
/// returns the array with the specified reshapedType as its only result.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *value, ToyArrayType reshapedType);
/// Inherit constructor.
using Op::Op;
};
/// Binary operation implementing a multiplication. For two-dimensional array
/// a matrix multiplication is implemented, while for one dimensional array a
/// dot product is performed.
class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
mlir::OpTrait::OneResult> {
public:
static llvm::StringRef getOperationName() { return "toy.mul"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.mul` operation accepts two operands as argument and returns
/// a single value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs);
/// Convenience accessor for LHS of the expression.
mlir::Value *getLHS() { return getOperand(0); }
/// Convenience accessor for RHS of the expression.
mlir::Value *getRHS() { return getOperand(1); }
/// Inherit constructor.
using Op::Op;
};
/// Element wise addition of two arrays. The shape must match.
class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
mlir::OpTrait::OneResult> {
public:
static llvm::StringRef getOperationName() { return "toy.add"; }
/// Operation can add custom verification beyond the traits they define.
mlir::LogicalResult verify();
/// Interface to mlir::Builder::create<PrintOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.mul` operation accepts two operands as argument and returns
/// a single value.
static void build(mlir::Builder *builder, mlir::OperationState &state,
mlir::Value *lhs, mlir::Value *rhs);
/// Convenience accessor for LHS of the expression.
mlir::Value *getLHS() { return getOperand(0); }
/// Convenience accessor for RHS of the expression.
mlir::Value *getRHS() { return getOperand(1); }
/// Inherit constructor.
using Op::Op;
};
/// Include the auto-generated header file containing the declarations of the
/// toy operations.
#define GET_OP_CLASSES
#include "toy/Ops.h.inc"
} // end namespace toy
} // end namespace mlir
#endif // MLIR_TUTORIAL_TOY_DIALECT_H_

View File

@@ -31,7 +31,7 @@ namespace toy {
/// Structure definition a location in a file.
struct Location {
std::shared_ptr<std::string> file; ///< filename
std::shared_ptr<std::string> file; ///< filename.
int line; ///< line number.
int col; ///< column number.
};

View File

@@ -0,0 +1,247 @@
//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines the operations of the Toy dialect.
//
//===----------------------------------------------------------------------===//
#ifdef TOY_OPS
#else
#define TOY_OPS
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
// Provide a definition of the 'toy' dialect in the ODS framework so that we
// can define our operations.
def Toy_Dialect : Dialect {
  // The namespace prefix used by every operation in this dialect ("toy.xxx").
  let name = "toy";
  // The C++ namespace the generated operation classes are placed in.
  let cppNamespace = "toy";
}

// Base class for toy dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
//   * The parent dialect of the operation.
//   * The mnemonic for the operation, or the name without the dialect prefix.
//   * A list of traits for the operation.
class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
    Op<Toy_Dialect, mnemonic, traits>;
//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//
// We define a toy operation by inheriting from our base 'Toy_Op' class above.
// Here we provide the mnemonic and a list of traits for the operation. The
// constant operation is marked as 'NoSideEffect' as it is a pure operation
// and may be removed if dead.
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
  // Provide a summary and description for this operation. This can be used to
  // auto-generate documentation of the operations within our dialect.
  let summary = "constant";
  let description = [{
    Constant operation turns a literal into an SSA value. The data is attached
    to the operation as an attribute. For example:

    ```mlir
    %0 = "toy.constant"()
    { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
    : () -> tensor<2x3xf64>
    ```
  }];

  // The constant operation takes an attribute as the only input.
  let arguments = (ins F64ElementsAttr:$value);

  // The constant operation returns a single value of TensorType.
  let results = (outs F64Tensor);

  // Add custom build methods for the constant operation. These methods
  // populate the `state` that MLIR uses to create operations, i.e. these are
  // used when using `builder.create<ConstantOp>(...)`.
  let builders = [
    // Build a constant with a given constant tensor value.
    OpBuilder<"Builder *builder, OperationState &result, "
              "DenseElementsAttr value", [{
      build(builder, result, value.getType(), value);
    }]>,

    // Build a constant with a given constant floating-point value.
    OpBuilder<"Builder *builder, OperationState &result, double value", [{
      buildConstantOp(builder, result, value);
    }]>
  ];

  // Invoke a static verify method to verify this constant operation.
  let verifier = [{ return ::verify(*this); }];
}
def AddOp : Toy_Op<"add", [NoSideEffect]> {
  let summary = "element-wise addition operation";
  let description = [{
    The "add" operation performs element-wise addition between two tensors.
    The shapes of the tensor operands are expected to match.
  }];

  // Both operands and the result are f64 tensors.
  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs F64Tensor);

  // Allow building an AddOp from the two input operands.
  let builders = [
    OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
      buildAddOp(b, result, lhs, rhs);
    }]
  >];
}
def GenericCallOp : Toy_Op<"generic_call"> {
  let summary = "generic call operation";
  let description = [{
    Generic calls represent calls to a user defined function that needs to
    be specialized for the shape of its arguments. The callee name is attached
    as a symbol reference via an attribute. The arguments list must match the
    arguments expected by the callee. For example:

    ```mlir
    %4 = "toy.generic_call"(%1, %3) {callee = @my_func}
    : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
    ```

    This is only valid if a function named "my_func" exists and takes two
    arguments.
  }];

  // The generic call operation takes a symbol reference attribute as the
  // callee, and inputs for the call.
  let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);

  // The generic call operation returns a single value of TensorType.
  let results = (outs F64Tensor);

  // Add custom build methods for the generic call operation.
  let builders = [
    // Build the generic call operation from the callee name and the call
    // arguments.
    OpBuilder<"Builder *builder, OperationState &result, "
              "StringRef callee, ArrayRef<Value *> arguments", [{
      buildGenericCallOp(builder, result, callee, arguments);
    }]>
  ];
}
def MulOp : Toy_Op<"mul", [NoSideEffect]> {
  let summary = "element-wise multiplication operation";
  let description = [{
    The "mul" operation performs element-wise multiplication between two
    tensors. The shapes of the tensor operands are expected to match.
  }];

  // Both operands and the result are f64 tensors.
  let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
  let results = (outs F64Tensor);

  // Allow building a MulOp from the two input operands.
  let builders = [
    OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
      buildMulOp(b, result, lhs, rhs);
    }]
  >];
}
def PrintOp : Toy_Op<"print"> {
  let summary = "print operation";
  let description = [{
    The "print" builtin operation prints a given input tensor, and produces
    no results.
  }];

  // The print operation takes an input tensor to print. It declares no
  // results, and no 'NoSideEffect' trait, since printing is a side effect.
  let arguments = (ins F64Tensor:$input);
}
def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
  let summary = "tensor reshape operation";
  let description = [{
    Reshape operation is transforming its input tensor into a new tensor with
    the same number of elements but different shapes. For example:

    ```mlir
    %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
    ```
  }];

  let arguments = (ins F64Tensor:$input);

  // Enable registering canonicalization patterns with this operation.
  let hasCanonicalizer = 1;

  // We expect that the reshape operation returns a statically shaped tensor.
  let results = (outs StaticShapeTensorOf<[F64]>);
}
def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
  let summary = "return operation";
  let description = [{
    The "return" operation represents a return operation within a function.
    The operation takes an optional tensor operand and produces no results.
    The operand type must match the signature of the function that contains
    the operation. For example:

    ```mlir
    func @foo() -> tensor<2xf64> {
    ...
    toy.return %0 : tensor<2xf64>
    }
    ```
  }];

  // The return operation takes an optional input operand to return. This
  // value must match the return type of the enclosing function.
  let arguments = (ins Variadic<F64Tensor>:$input);

  // Allow building a ReturnOp with no return operand.
  let builders = [OpBuilder<
    "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
  >];

  // Provide extra utility definitions on the c++ operation class definition.
  let extraClassDeclaration = [{
    bool hasOperand() { return getNumOperands() != 0; }
  }];

  // Invoke a static verify method to verify this return operation.
  let verifier = [{ return ::verify(*this); }];
}
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
  let summary = "transpose operation";

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor);

  // Enable registering canonicalization patterns with this operation.
  let hasCanonicalizer = 1;

  // Allow building a TransposeOp from the single input operand.
  let builders = [
    OpBuilder<"Builder *b, OperationState &result, Value *input", [{
      buildTransposeOp(b, result, input);
    }]
  >];
}
#endif // TOY_OPS

View File

@@ -0,0 +1,151 @@
//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
using namespace mlir::toy;
//===----------------------------------------------------------------------===//
// ToyDialect
//===----------------------------------------------------------------------===//
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  // Register every operation generated by ODS from toy/Ops.td with this
  // dialect so the MLIR parser/verifier knows about them.
  addOperations<
#define GET_OP_LIST
#include "toy/Ops.cpp.inc"
      >();
}
//===----------------------------------------------------------------------===//
// Toy Operations
//===----------------------------------------------------------------------===//
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation.
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
                            double value) {
  // Represent the scalar as a rank-0 f64 tensor holding the single value,
  // then delegate to the ODS-generated ConstantOp::build.
  auto dataType = builder->getTensorType({}, builder->getF64Type());
  auto dataAttribute = DenseElementsAttr::get(dataType, value);
  ConstantOp::build(builder, state, dataType, dataAttribute);
}
/// Verifier for the constant operation.
///
/// When the result type is a ranked tensor, its rank and every dimension must
/// match the attached value attribute. An unranked result is accepted as-is.
static mlir::LogicalResult verify(ConstantOp op) {
  // If the return type of the constant is not a ranked tensor there is no
  // shape to check against. Note: this must be dyn_cast, not cast — cast<>
  // asserts (undefined behavior in release builds) when the result is an
  // unranked tensor, which would make the null-check below dead code.
  auto resultType = op.getResult()->getType().dyn_cast<RankedTensorType>();
  if (!resultType)
    return success();

  // The rank of the attribute must match the rank of the result type.
  auto attrType = op.value().getType().cast<mlir::TensorType>();
  if (attrType.getRank() != resultType.getRank()) {
    return op.emitOpError(
               "return type must match the one of the attached value "
               "attribute: ")
           << attrType.getRank() << " != " << resultType.getRank();
  }

  // Every dimension of the result must match the attribute's shape.
  for (int dim = 0; dim < attrType.getRank(); ++dim) {
    if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
      return op.emitOpError(
                 "return type shape mismatches its attribute at dimension ")
             << dim << ": " << attrType.getShape()[dim]
             << " != " << resultType.getShape()[dim];
    }
  }
  return mlir::success();
}
/// Build an addition operation.
/// The result type is left as an unranked f64 tensor; the two inputs become
/// the operands.
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::Value *lhs, mlir::Value *rhs) {
  state.addTypes(builder->getTensorType(builder->getF64Type()));
  state.addOperands({lhs, rhs});
}
/// Build a generic call operation.
/// The callee name is attached as a "callee" symbol reference attribute and
/// the call arguments become the operands.
static void buildGenericCallOp(mlir::Builder *builder,
                               mlir::OperationState &state, StringRef callee,
                               ArrayRef<mlir::Value *> arguments) {
  // Generic call always returns an unranked Tensor initially.
  state.addTypes(builder->getTensorType(builder->getF64Type()));
  state.addOperands(arguments);
  state.addAttribute("callee", builder->getSymbolRefAttr(callee));
}
/// Build a multiplication operation.
/// The result type is left as an unranked f64 tensor; the two inputs become
/// the operands.
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::Value *lhs, mlir::Value *rhs) {
  state.addTypes(builder->getTensorType(builder->getF64Type()));
  state.addOperands({lhs, rhs});
}
/// Verifier for the return operation.
///
/// Checks that the operation has at most one operand and that its operand
/// count/type agrees with the result signature of the enclosing function.
static mlir::LogicalResult verify(ReturnOp op) {
  // We know that the parent operation is a function, because of the 'HasParent'
  // trait attached to the operation definition.
  auto function = cast<FuncOp>(op.getParentOp());

  /// ReturnOps can only have a single optional operand.
  if (op.getNumOperands() > 1)
    return op.emitOpError() << "expects at most 1 return operand";

  // The operand number and types must match the function signature.
  const auto &results = function.getType().getResults();
  if (op.getNumOperands() != results.size())
    return op.emitOpError()
           << "does not return the same number of values ("
           << op.getNumOperands() << ") as the enclosing function ("
           << results.size() << ")";

  // If the operation does not have an input, we are done.
  if (!op.hasOperand())
    return mlir::success();

  auto inputType = *op.operand_type_begin();
  auto resultType = results.front();

  // Check that the result type of the function matches the operand type. An
  // unranked tensor on either side is treated as compatible.
  if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
      resultType.isa<mlir::UnrankedTensorType>())
    return mlir::success();

  return op.emitError() << "type of return operand ("
                        << *op.operand_type_begin()
                        << ") doesn't match function result type ("
                        << results.front() << ")";
}
/// Build a transpose operation.
/// The result type is left as an unranked f64 tensor; the single input value
/// becomes the operand.
static void buildTransposeOp(mlir::Builder *builder,
                             mlir::OperationState &state, mlir::Value *value) {
  state.addTypes(builder->getTensorType(builder->getF64Type()));
  state.addOperands(value);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "toy/Ops.cpp.inc"

View File

@@ -25,30 +25,30 @@
#include "toy/Dialect.h"
#include "mlir/Analysis/Verifier.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
using namespace mlir::toy;
using namespace toy;
using llvm::ArrayRef;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;
using llvm::makeArrayRef;
using llvm::ScopedHashTableScope;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
using std::make_unique;
namespace {
@@ -57,56 +57,43 @@ namespace {
/// This will emit operations that are specific to the Toy language, preserving
/// the semantics of the language and (hopefully) allow to perform accurate
/// analysis and transformation based on these high level semantics.
///
/// At this point we take advantage of the "raw" MLIR APIs to create operations
/// that haven't been registered in any way with MLIR. These operations are
/// unknown to MLIR, custom passes could operate by string-matching the name of
/// these operations, but no other type checking or semantic is associated with
/// them natively by MLIR.
class MLIRGenImpl {
public:
MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
/// Public API: convert the AST for a Toy module (source file) to an MLIR
/// Module.
mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
/// Module operation.
mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
for (FunctionAST &F : moduleAST) {
auto func = mlirGen(F);
if (!func)
return nullptr;
theModule->push_back(func);
theModule.push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
// this won't do much, but it should at least check some structural
// properties.
if (failed(mlir::verify(*theModule))) {
emitError(mlir::UnknownLoc::get(&context), "module verification error");
// Verify the module after we have finished constructing it, this will check
// the structural properties of the IR and invoke any specific verifiers we
// have on the Toy operations.
if (failed(mlir::verify(theModule))) {
theModule.emitError("module verification error");
return nullptr;
}
return std::move(theModule);
return theModule;
}
private:
/// In MLIR (like in LLVM) a "context" object holds the memory allocation and
/// the ownership of many internal structure of the IR and provide a level
/// of "uniquing" across multiple modules (types for instance).
mlir::MLIRContext &context;
/// A "module" matches a Toy source file: containing a list of functions.
mlir::ModuleOp theModule;
/// A "module" matches a source file: it contains a list of functions.
mlir::OwningModuleRef theModule;
/// The builder is a helper class to create IR inside a function. It is
/// re-initialized every time we enter a function and kept around as a
/// convenience for emitting individual operations.
/// The builder is stateful, in particular it keeps an "insertion point":
/// this is where the next operations will be introduced.
std::unique_ptr<mlir::OpBuilder> builder;
/// The builder is a helper class to create IR inside a function. The builder
/// is stateful, in particular it keeps an "insertion point": this is where
/// the next operations will be introduced.
mlir::OpBuilder builder;
/// The symbol table maps a variable name to a value in the current scope.
/// Entering a function creates a new scope, and the function arguments are
@@ -116,37 +103,35 @@ private:
/// Helper conversion for a Toy AST location to an MLIR location.
mlir::Location loc(Location loc) {
return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context),
loc.line, loc.col, &context);
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
loc.col);
}
/// Declare a variable in the current scope, return true if the variable
/// Declare a variable in the current scope, return success if the variable
/// wasn't declared yet.
bool declare(llvm::StringRef var, mlir::Value *value) {
if (symbolTable.count(var)) {
return false;
}
mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
if (symbolTable.count(var))
return mlir::failure();
symbolTable.insert(var, value);
return true;
return mlir::success();
}
/// Create the prototype for an MLIR function with as many arguments as the
/// provided Toy AST prototype.
mlir::FuncOp mlirGen(PrototypeAST &proto) {
auto location = loc(proto.loc());
// This is a generic function, the return type will be inferred later.
llvm::SmallVector<mlir::Type, 4> ret_types;
// Arguments type is uniformly a generic array.
// Arguments type are uniformly unranked tensors.
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
getType(VarType{}));
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(),
func_type, /* attrs = */ {});
auto func_type = builder.getFunctionType(arg_types, llvm::None);
auto function = mlir::FuncOp::create(location, proto.getName(), func_type);
// Mark the function as generic: it'll require type specialization for every
// call site.
if (function.getNumArguments())
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
function.setAttr("toy.generic", builder.getUnitAttr());
return function;
}
@@ -165,18 +150,22 @@ private:
// argument list as the function itself.
auto &entryBlock = *function.addEntryBlock();
auto &protoArgs = funcAST.getProto()->getArgs();
// Declare all the function arguments in the symbol table.
for (const auto &name_value :
llvm::zip(protoArgs, entryBlock.getArguments())) {
declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
if (failed(declare(std::get<0>(name_value)->getName(),
std::get<1>(name_value))))
return nullptr;
}
// Create a builder for the function, it will be used throughout the codegen
// to create operations in this function.
builder = std::make_unique<mlir::OpBuilder>(function.getBody());
// Set the insertion point in the builder to the beginning of the function
// body, it will be used throughout the codegen to create operations in this
// function.
builder.setInsertionPointToStart(&entryBlock);
// Emit the body of the function.
if (!mlirGen(*funcAST.getBody())) {
if (mlir::failed(mlirGen(*funcAST.getBody()))) {
function.erase();
return nullptr;
}
@@ -184,10 +173,16 @@ private:
// Implicitly return void if no return statement was emitted.
// FIXME: we may fix the parser instead to always return the last expression
// (this would possibly help the REPL case later)
if (function.getBlocks().back().back().getName().getStringRef() !=
"toy.return") {
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
mlirGen(fakeRet);
ReturnOp returnOp;
if (!entryBlock.empty())
returnOp = dyn_cast<ReturnOp>(entryBlock.back());
if (!returnOp) {
builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
} else if (returnOp.hasOperand()) {
// Otherwise, if this return operation has an operand then add a result to
// the function.
function.setType(builder.getFunctionType(function.getType().getInputs(),
getType(VarType{})));
}
return function;
@@ -206,11 +201,11 @@ private:
// and the result value is returned. If an error occurs we get a nullptr
// and propagate.
//
mlir::Value *L = mlirGen(*binop.getLHS());
if (!L)
mlir::Value *lhs = mlirGen(*binop.getLHS());
if (!lhs)
return nullptr;
mlir::Value *R = mlirGen(*binop.getRHS());
if (!R)
mlir::Value *rhs = mlirGen(*binop.getRHS());
if (!rhs)
return nullptr;
auto location = loc(binop.loc());
@@ -218,123 +213,113 @@ private:
// support '+' and '*'.
switch (binop.getOp()) {
case '+':
return builder->create<AddOp>(location, L, R).getResult();
break;
return builder.create<AddOp>(location, lhs, rhs);
case '*':
return builder->create<MulOp>(location, L, R).getResult();
default:
emitError(loc(binop.loc()), "error: invalid binary operator '")
<< binop.getOp() << "'";
return nullptr;
return builder.create<MulOp>(location, lhs, rhs);
}
emitError(location, "invalid binary operator '") << binop.getOp() << "'";
return nullptr;
}
// This is a reference to a variable in an expression. The variable is
// expected to have been declared and so should have a value in the symbol
// table, otherwise emit an error and return nullptr.
/// This is a reference to a variable in an expression. The variable is
/// expected to have been declared and so should have a value in the symbol
/// table, otherwise emit an error and return nullptr.
mlir::Value *mlirGen(VariableExprAST &expr) {
if (symbolTable.count(expr.getName()))
return symbolTable.lookup(expr.getName());
if (auto *variable = symbolTable.lookup(expr.getName()))
return variable;
emitError(loc(expr.loc()), "error: unknown variable '")
<< expr.getName() << "'";
return nullptr;
}
// Emit a return operation, return true on success.
bool mlirGen(ReturnExprAST &ret) {
/// Emit a return operation. This will return failure if any generation fails.
mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
auto location = loc(ret.loc());
// `return` takes an optional expression, we need to account for it here.
if (!ret.getExpr().hasValue()) {
builder->create<ReturnOp>(location);
return true;
// 'return' takes an optional expression, handle that case here.
mlir::Value *expr = nullptr;
if (ret.getExpr().hasValue()) {
if (!(expr = mlirGen(*ret.getExpr().getValue())))
return mlir::failure();
}
auto *expr = mlirGen(*ret.getExpr().getValue());
if (!expr)
return false;
builder->create<ReturnOp>(location, expr);
return true;
// Otherwise, this return operation has zero operands.
builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
: ArrayRef<mlir::Value *>());
return mlir::success();
}
// Emit a literal/constant array. It will be emitted as a flattened array of
// data in an Attribute attached to a `toy.constant` operation.
// See documentation on [Attributes](LangRef.md#attributes) for more details.
// Here is an excerpt:
//
// Attributes are the mechanism for specifying constant data in MLIR in
// places where a variable is never allowed [...]. They consist of a name
// and a [concrete attribute value](#attribute-values). It is possible to
// attach attributes to operations, functions, and function arguments. The
// set of expected attributes, their structure, and their interpretation
// are all contextually dependent on what they are attached to.
//
// Example, the source level statement:
// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
// will be converted to:
// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
//
/// Emit a literal/constant array. It will be emitted as a flattened array of
/// data in an Attribute attached to a `toy.constant` operation.
/// See documentation on [Attributes](LangRef.md#attributes) for more details.
/// Here is an excerpt:
///
/// Attributes are the mechanism for specifying constant data in MLIR in
/// places where a variable is never allowed [...]. They consist of a name
/// and a concrete attribute value. The set of expected attributes, their
/// structure, and their interpretation are all contextually dependent on
/// what they are attached to.
///
/// Example, the source level statement:
/// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
/// will be converted to:
/// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
/// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
/// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
///
mlir::Value *mlirGen(LiteralExprAST &lit) {
auto location = loc(lit.loc());
// The attribute is a vector with an attribute per element (number) in the
// array, see `collectData()` below for more details.
std::vector<mlir::Attribute> data;
auto type = getType(lit.getDims());
// The attribute is a vector with a floating point value per element
// (number) in the array, see `collectData()` below for more details.
std::vector<double> data;
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
std::multiplies<int>()));
collectData(lit, data);
// FIXME: using a tensor type is a HACK here.
// Can we do differently without registering a dialect? Using a string blob?
mlir::Type elementType = mlir::FloatType::getF64(&context);
auto dataType = builder->getTensorType(lit.getDims(), elementType);
// The type of this attribute is tensor of 64-bit floating-point with the
// shape of the literal.
mlir::Type elementType = builder.getF64Type();
auto dataType = builder.getTensorType(lit.getDims(), elementType);
// This is the actual attribute that actually hold the list of values for
// this array literal.
auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
.cast<mlir::DenseElementsAttr>();
// This is the actual attribute that holds the list of values for this
// tensor literal.
auto dataAttribute =
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
// Build the MLIR op `toy.constant`, only boilerplate below.
return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
.getResult();
// Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
// method.
return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
}
// Recursive helper function to accumulate the data that compose an array
// literal. It flattens the nested structure in the supplied vector. For
// example with this array:
// [[1, 2], [3, 4]]
// we will generate:
// [ 1, 2, 3, 4 ]
// Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
// Attributes are the way MLIR attaches constant to operations and functions.
void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
/// Recursive helper function to accumulate the data that compose an array
/// literal. It flattens the nested structure in the supplied vector. For
/// example with this array:
/// [[1, 2], [3, 4]]
/// we will generate:
/// [ 1, 2, 3, 4 ]
/// Individual numbers are represented as doubles.
/// Attributes are the way MLIR attaches constant to operations.
void collectData(ExprAST &expr, std::vector<double> &data) {
if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
for (auto &value : lit->getValues())
collectData(*value, data);
return;
}
assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
mlir::Type elementType = mlir::FloatType::getF64(&context);
auto attr = mlir::FloatAttr::getChecked(
elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
data.push_back(attr);
data.push_back(cast<NumberExprAST>(expr).getValue());
}
// Emit a call expression. It emits specific operations for the `transpose`
// builtin. Other identifiers are assumed to be user-defined functions.
/// Emit a call expression. It emits specific operations for the `transpose`
/// builtin. Other identifiers are assumed to be user-defined functions.
mlir::Value *mlirGen(CallExprAST &call) {
llvm::StringRef callee = call.getCallee();
auto location = loc(call.loc());
std::string callee = call.getCallee();
if (callee == "transpose") {
if (call.getArgs().size() != 1) {
emitError(location, "MLIR codegen encountered an error: toy.transpose "
"does not accept multiple arguments");
return nullptr;
}
mlir::Value *arg = mlirGen(*call.getArgs()[0]);
return builder->create<TransposeOp>(location, arg).getResult();
}
// Codegen the operands first
// Codegen the operands first.
SmallVector<mlir::Value *, 4> operands;
for (auto &expr : call.getArgs()) {
auto *arg = mlirGen(*expr);
@@ -342,34 +327,41 @@ private:
return nullptr;
operands.push_back(arg);
}
// Calls to user-defined function are mapped to a custom call that takes
// the callee name as an attribute.
return builder->create<GenericCallOp>(location, call.getCallee(), operands)
.getResult();
// Builtin calls have their own custom operation, meaning this is a
// straightforward emission.
if (callee == "transpose") {
if (call.getArgs().size() != 1) {
emitError(location, "MLIR codegen encountered an error: toy.transpose "
"does not accept multiple arguments");
return nullptr;
}
return builder.create<TransposeOp>(location, operands[0]);
}
// Otherwise this is a call to a user-defined function. Calls to user-defined
// functions are mapped to a custom call that takes the callee name as an
// attribute.
return builder.create<GenericCallOp>(location, callee, operands);
}
// Emit a call expression. It emits specific operations for two builtins:
// transpose(x) and print(x). Other identifiers are assumed to be user-defined
// functions. Return false on failure.
bool mlirGen(PrintExprAST &call) {
/// Emit a print expression. It emits specific operations for two builtins:
/// transpose(x) and print(x).
mlir::LogicalResult mlirGen(PrintExprAST &call) {
auto *arg = mlirGen(*call.getArg());
if (!arg)
return false;
auto location = loc(call.loc());
builder->create<PrintOp>(location, arg);
return true;
return mlir::failure();
builder.create<PrintOp>(loc(call.loc()), arg);
return mlir::success();
}
// Emit a constant for a single number (FIXME: semantic? broadcast?)
/// Emit a constant for a single number (FIXME: semantic? broadcast?)
mlir::Value *mlirGen(NumberExprAST &num) {
auto location = loc(num.loc());
mlir::Type elementType = mlir::FloatType::getF64(&context);
auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
loc(num.loc()));
return builder->create<ConstantOp>(location, attr).getResult();
return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
}
// Dispatch codegen for the right expression subclass using RTTI.
/// Dispatch codegen for the right expression subclass using RTTI.
mlir::Value *mlirGen(ExprAST &expr) {
switch (expr.getKind()) {
case toy::ExprAST::Expr_BinOp:
@@ -390,77 +382,75 @@ private:
}
}
// Handle a variable declaration, we'll codegen the expression that forms the
// initializer and record the value in the symbol table before returning it.
// Future expressions will be able to reference this variable through symbol
// table lookup.
/// Handle a variable declaration, we'll codegen the expression that forms the
/// initializer and record the value in the symbol table before returning it.
/// Future expressions will be able to reference this variable through symbol
/// table lookup.
mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
mlir::Value *value = nullptr;
auto location = loc(vardecl.loc());
if (auto init = vardecl.getInitVal()) {
value = mlirGen(*init);
if (!value)
return nullptr;
// We have the initializer value, but in case the variable was declared
// with specific shape, we emit a "reshape" operation. It will get
// optimized out later as needed.
if (!vardecl.getType().shape.empty()) {
value = builder
->create<ReshapeOp>(
location, value,
getType(vardecl.getType()).cast<ToyArrayType>())
.getResult();
}
} else {
auto init = vardecl.getInitVal();
if (!init) {
emitError(loc(vardecl.loc()),
"missing initializer in variable declaration");
return nullptr;
}
// Register the value in the symbol table
declare(vardecl.getName(), value);
mlir::Value *value = mlirGen(*init);
if (!value)
return nullptr;
// We have the initializer value, but in case the variable was declared
// with specific shape, we emit a "reshape" operation. It will get
// optimized out later as needed.
if (!vardecl.getType().shape.empty()) {
value = builder.create<ReshapeOp>(loc(vardecl.loc()),
getType(vardecl.getType()), value);
}
// Register the value in the symbol table.
if (failed(declare(vardecl.getName(), value)))
return nullptr;
return value;
}
/// Codegen a list of expression, return false if one of them hit an error.
bool mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
/// Codegen a list of expression, return failure if one of them hit an error.
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
for (auto &expr : blockAST) {
// Specific handling for variable declarations, return statement, and
// print. These can only appear in block list and not in nested
// expressions.
if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
if (!mlirGen(*vardecl))
return false;
return mlir::failure();
continue;
}
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
if (!mlirGen(*ret))
return false;
return true;
}
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
return mlirGen(*ret);
if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
if (!mlirGen(*print))
return false;
if (mlir::failed(mlirGen(*print)))
return mlir::success();
continue;
}
// Generic expression dispatch codegen.
if (!mlirGen(*expr))
return false;
return mlir::failure();
}
return true;
return mlir::success();
}
/// Build a type from a list of shape dimensions. Types are `array` followed
/// by an optional dimension list, example: array<2, 2>
/// They are wrapped in a `toy` dialect (see next chapter) and get printed:
/// !toy.array<2, 2>
template <typename T> mlir::Type getType(T shape) {
SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
return ToyArrayType::get(&context, shape64);
/// Build a tensor type from a list of shape dimensions.
mlir::Type getType(ArrayRef<int64_t> shape) {
// If the shape is empty, then this type is unranked.
if (shape.empty())
return builder.getTensorType(builder.getF64Type());
// Otherwise, we use the given shape.
return builder.getTensorType(shape, builder.getF64Type());
}
/// Build an MLIR type from a Toy AST variable type
/// (forward to the generic getType(T) above).
/// Build an MLIR type from a Toy AST variable type (forward to the generic
/// getType above).
mlir::Type getType(const VarType &type) { return getType(type.shape); }
};

View File

@@ -0,0 +1,75 @@
//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a simple combiner for optimizing pattern in the Toy
// dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "toy/Dialect.h"
#include <numeric>
using namespace mlir;
using namespace toy;
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
#include "ToyCombine.inc"
} // end anonymous namespace
/// Fold transpose(transpose(x)) -> x.
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
  /// Register this pattern to match on every toy.transpose in the IR. The
  /// "benefit" is used by the framework to order patterns and process them in
  /// order of profitability.
  SimplifyRedundantTranspose(mlir::MLIRContext *context)
      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}

  /// Try to match a transpose whose input is itself the result of a
  /// transpose, and rewrite the pair down to the innermost value. The
  /// rewriter argument orchestrates the sequence of rewrites; all IR
  /// mutations must go through it.
  mlir::PatternMatchResult
  matchAndRewrite(TransposeOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // Walk through the input of the current transpose: was it produced by
    // another transpose? If not, there is nothing to do here.
    mlir::Value *input = op.getOperand();
    auto inputOp =
        llvm::dyn_cast_or_null<TransposeOp>(input->getDefiningOp());
    if (!inputOp)
      return matchFailure();

    // transpose(transpose(x)) == x: forward the inner operand to all users of
    // the outer op, and mark the inner transpose as a candidate for removal.
    rewriter.replaceOp(op, {inputOp.getOperand()}, {inputOp});
    return matchSuccess();
  }
};
/// Register our patterns for rewrite by the Canonicalization framework: every
/// time the canonicalizer visits a `toy.transpose`, it will consider the
/// C++-defined SimplifyRedundantTranspose pattern.
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<SimplifyRedundantTranspose>(context);
}
/// Register our patterns for rewrite by the Canonicalization framework: for
/// `toy.reshape` these are the three DRR-generated patterns pulled in from
/// the tablegen'd ToyCombine.inc.
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
                 FoldConstantReshapeOptPattern>(context);
}

View File

@@ -0,0 +1,72 @@
//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Defines language-specific pattern match optimizations for Toy using
// Declarative Rewrite Rules (DRR) specified using TableGen records.
//
//===----------------------------------------------------------------------===//
#ifndef TOY_COMBINE
#define TOY_COMBINE
#ifndef OP_BASE
include "toy/Ops.td"
#endif // OP_BASE
/* Pattern-Match and Rewrite using DRR:
class Pattern<
dag sourcePattern, list<dag> resultPatterns,
list<dag> additionalConstraints = [],
dag benefitsAdded = (addBenefit 0)>;
*/
//===----------------------------------------------------------------------===//
// Basic Pattern-Match and Rewrite
//===----------------------------------------------------------------------===//
// Reshape(Reshape(x)) = Reshape(x): collapse a chain of two reshapes into a
// single reshape of the original value.
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
                                   (ReshapeOp $arg)>;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite using Native Code Call
//===----------------------------------------------------------------------===//
// Native Code Calls may be used for more complex transformations using inline
// C++ and C++ helper functions.

// Reshape(Constant(x)) = x': fold the reshape into the constant itself by
// reshaping the underlying dense attribute ($0 = the value attribute) to the
// result type of the reshape ($1 = the `res` op result bound below).
def ReshapeConstant :
  NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
def FoldConstantReshapeOptPattern : Pat<
  (ReshapeOp:$res (ConstantOp $arg)),
  (ConstantOp (ReshapeConstant $arg, $res))>;
//===----------------------------------------------------------------------===//
// Pattern-Match and Rewrite with Constraints
//===----------------------------------------------------------------------===//
// DRR allows for constraint checking when the transformation is conditional
// on operand properties.

// Reshape(x) = x, where input and output shapes are identical: the reshape is
// then a no-op and all uses can be forwarded to its operand.
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
def RedundantReshapeOptPattern : Pat<
  (ReshapeOp:$res $arg), (replaceWithValue $arg),
  [(TypesAreIdentical $res, $arg)]>;
#endif // TOY_COMBINE

View File

@@ -1,390 +0,0 @@
//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the dialect for the Toy IR: custom type parsing and
// operation verification.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
using llvm::ArrayRef;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
namespace toy {
namespace detail {
/// This class holds the implementation (the storage) of the ToyArrayType.
/// Instances are uniqued based on their content and owned by the MLIRContext.
struct ToyArrayTypeStorage : public mlir::TypeStorage {
  /// This defines how we unique this type in the context: our key contains
  /// only the shape; a more complex type would have multiple entries in the
  /// tuple here.
  /// The elements of the tuple usually match 1-1 the arguments of the
  /// public `get()` method on the type facade.
  using KeyTy = std::tuple<ArrayRef<int64_t>>;
  /// Hash the key so the context can bucket uniqued instances.
  static unsigned hashKey(const KeyTy &key) {
    return llvm::hash_combine(std::get<0>(key));
  }
  /// When the key hash hits an existing type, we compare the shapes
  /// themselves to confirm we have the right type.
  bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
  /// This is a factory method to create our type storage. It is only
  /// invoked after looking up the type in the context using the key and not
  /// finding it.
  static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
                                        const KeyTy &key) {
    // Copy the shape array into the bump-pointer allocator owned by the
    // context, so the stored ArrayRef never dangles.
    ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
    // Allocate the instance for the ToyArrayTypeStorage itself.
    auto *storage = allocator.allocate<ToyArrayTypeStorage>();
    // Initialize the instance using placement new.
    return new (storage) ToyArrayTypeStorage(shape);
  }
  ArrayRef<int64_t> getShape() const { return shape; }

private:
  ArrayRef<int64_t> shape;
  /// Constructor is only invoked from the `construct()` method above.
  ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
};
} // namespace detail
/// Toy arrays always hold 64-bit floating point elements.
mlir::Type ToyArrayType::getElementType() {
  return mlir::FloatType::getF64(getContext());
}
/// Get or create an instance of the array type with the given shape; the
/// instance is uniqued in (and owned by) the given context.
ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
                               ArrayRef<int64_t> shape) {
  return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
}
/// Return the shape held by the uniqued type storage (empty for the generic,
/// unshaped `array` type).
ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
/// Dialect creation, the instance will be owned by the context. This is the
/// point of registration of custom types and operations for the dialect.
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
  // Register every Toy operation class and the Toy array type so the context
  // can create, parse, print, and verify them.
  addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
                MulOp, AddOp, ReturnOp>();
  addTypes<ToyArrayType>();
}
/// Parse a type registered to this dialect; we expect only Toy arrays.
/// Accepted forms are `array` (the generic, unshaped array) and
/// `array<d0, d1, ...>`. On malformed input an error is emitted at `loc` and
/// a null type is returned.
mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
  // Sanity check: we only support array or array<...>
  if (!tyData.startswith("array")) {
    emitError(loc, "invalid Toy type '" + tyData + "', array expected");
    return nullptr;
  }
  // Drop the "array" prefix from the type name, we expect either an empty
  // string or just the shape.
  tyData = tyData.drop_front(StringRef("array").size());
  // This is the generic array case without shape, early return it.
  if (tyData.empty())
    return ToyArrayType::get(getContext());

  // Validate the overall syntax of the shape: '<' dim (', ' dim)* '>'. (For
  // efficiency we should store this regex in the dialect itself.)
  auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
  if (!shapeRegex.match(tyData)) {
    emitError(loc, "invalid toy array shape '" + tyData + "'");
    return nullptr;
  }

  // Extract the dimensions by splitting on ',' rather than relying on the
  // regex captures: a POSIX ERE only records the *last* repetition of a
  // capture group, which would silently drop the middle dimensions of any
  // shape with three or more entries (e.g. `array<2, 3, 4>` would otherwise
  // be parsed as shape {2, 4}).
  SmallVector<StringRef, 4> dimStrs;
  tyData.drop_front().drop_back().split(dimStrs, ',');
  SmallVector<int64_t, 4> shape;
  for (StringRef dimStr : dimStrs) {
    // Convert each dimension to an integer.
    unsigned long long dim;
    if (getAsUnsignedInteger(dimStr.trim(), /* Radix = */ 10, dim)) {
      emitError(loc, "couldn't parse dimension as integer, matched: " +
                         dimStr.trim());
      return mlir::Type();
    }
    shape.push_back(dim);
  }
  // Finally we collected all the dimensions in the shape,
  // create the array type.
  return ToyArrayType::get(getContext(), shape);
}
/// Print a Toy array type, for example `array<2, 3, 4>`
/// Print a Toy array type, for example `array<2, 3, 4>`. A shapeless array
/// prints as just `array`.
void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
  auto arrayTy = type.dyn_cast<ToyArrayType>();
  if (!arrayTy) {
    os << "unknown toy type";
    return;
  }
  os << "array";
  auto shape = arrayTy.getShape();
  if (shape.empty())
    return;
  os << "<";
  mlir::interleaveComma(shape, os);
  os << ">";
}
////////////////////////////////////////////////////////////////////////////////
//////////////////// Custom Operations for the Dialect /////////////////////////
////////////////////////////////////////////////////////////////////////////////
/// Helper to verify that the result of an operation is a Toy array type;
/// emits an op error and returns failure otherwise.
template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
  if (op->getResult()->getType().template isa<ToyArrayType>())
    return mlir::success();
  std::string msg;
  raw_string_ostream os(msg);
  os << "expects a Toy Array for its argument, got "
     << op->getResult()->getType();
  return op->emitOpError(os.str());
}
/// Helper to verify that the two operands of a binary operation are Toy
/// arrays; emits an op error naming the offending side (LHS or RHS) and
/// returns failure otherwise.
template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
  // Check both operands with a single loop; the original code duplicated the
  // check and, by copy-paste, reported "LHS" and operand 0's type for a bad
  // RHS as well.
  for (unsigned i : {0u, 1u}) {
    if (op->getOperand(i)->getType().template isa<ToyArrayType>())
      continue;
    std::string msg;
    raw_string_ostream os(msg);
    os << "expects a Toy Array for its " << (i == 0 ? "LHS" : "RHS")
       << ", got " << op->getOperand(i)->getType();
    return op->emitOpError(os.str());
  }
  return mlir::success();
}
/// Build a constant operation.
/// The builder is passed as an argument, so is the state that this method is
/// expected to fill in order to build the operation: here the result type (a
/// Toy array of the requested shape) and the "value" data attribute.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
                       ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
  state.types.push_back(ToyArrayType::get(builder->getContext(), shape));
  state.attributes.push_back(builder->getNamedAttr("value", value));
}
/// Build a constant operation from a single scalar value.
/// The scalar is wrapped into a one-element f64 tensor attribute, and the
/// shape-based build factory above is reused to fill `state`.
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
                       mlir::FloatAttr value) {
  // Broadcast and forward to the other build factory
  auto f64Type = mlir::FloatType::getF64(builder->getContext());
  auto tensorType = builder->getTensorType({1}, f64Type);
  auto elements = builder->getDenseElementsAttr(tensorType, {value})
                      .cast<mlir::DenseElementsAttr>();
  ConstantOp::build(builder, state, {1}, elements);
}
/// Verifier for constant operation.
/// Checks that the result is a Toy array, that the `value` attribute is a
/// DenseElementsAttr over a tensor type, and, when the result shape is known,
/// that it matches the attribute's shape dimension by dimension.
mlir::LogicalResult ConstantOp::verify() {
  // Ensure that the return type is a Toy array
  if (failed(verifyToyReturnArray(this)))
    return mlir::failure();

  // We expect the constant itself to be stored as an attribute.
  auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
  if (!dataAttr) {
    return emitOpError(
        "missing valid `value` DenseElementsAttribute on toy.constant()");
  }
  auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
  if (!attrType) {
    return emitOpError(
        "missing valid `value` DenseElementsAttribute on toy.constant()");
  }

  // If the return type of the constant is not a generic array, the shape must
  // match the shape of the attribute holding the data.
  auto resultType = getResult()->getType().cast<ToyArrayType>();
  if (!resultType.isGeneric()) {
    if (attrType.getRank() != resultType.getRank()) {
      return emitOpError("The rank of the toy.constant return type must match "
                         "the one of the attached value attribute: " +
                         Twine(attrType.getRank()) +
                         " != " + Twine(resultType.getRank()));
    }
    for (int dim = 0; dim < attrType.getRank(); ++dim) {
      if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
        // Removed a dead `std::string msg; raw_string_ostream os(msg);` pair
        // here: the diagnostic below is built from Twines and never used them.
        return emitOpError(
            "Shape mismatch between toy.constant return type and its "
            "attribute at dimension " +
            Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
            " != " + Twine(resultType.getShape()[dim]));
      }
    }
  }
  return mlir::success();
}
/// Build a call to a user-defined function.
/// Fills `state` with the (generic) result type, the call operands, and the
/// callee name recorded as a string attribute.
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
                          StringRef callee, ArrayRef<mlir::Value *> arguments) {
  // Generic call always returns a generic ToyArray initially
  auto genericArrayType = ToyArrayType::get(builder->getContext());
  state.types.push_back(genericArrayType);
  state.operands.assign(arguments.begin(), arguments.end());
  state.attributes.push_back(
      builder->getNamedAttr("callee", builder->getStringAttr(callee)));
}
/// Verifier for the generic call operation.
/// Every operand passed to the callee must be a Toy array.
mlir::LogicalResult GenericCallOp::verify() {
  // Verify that every operand is a Toy Array
  for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
    // Dropped the superfluous `.template` disambiguator: this member function
    // is not a template, so the context is not dependent.
    if (!getOperand(opId)->getType().isa<ToyArrayType>()) {
      std::string msg;
      raw_string_ostream os(msg);
      os << "expects a Toy Array for its " << opId << " operand, got "
         << getOperand(opId)->getType();
      return emitOpError(os.str());
    }
  }
  return mlir::success();
}
/// Return the name of the callee.
StringRef GenericCallOp::getCalleeName() {
  // The callee name is stored as a string attribute on the operation.
  auto calleeAttr = getAttr("callee").cast<mlir::StringAttr>();
  return calleeAttr.getValue();
}
/// Helper to verify that the single operand of an operation is a Toy array.
/// Emits an op error (and returns failure) otherwise.
template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
  auto operandType = op->getOperand()->getType();
  if (operandType.template isa<ToyArrayType>())
    return mlir::success();
  std::string msg;
  raw_string_ostream os(msg);
  os << "expects a Toy Array for its argument, got " << operandType;
  return op->emitOpError(os.str());
}
/// Build a return operation.
/// `toy.return` produces no result and takes an optional single operand;
/// `value` may be null for a bare return.
void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state,
                     mlir::Value *value) {
  // Return does not return any value and has an optional single argument
  if (!value)
    return;
  state.operands.push_back(value);
}
/// Verifier for the return operation: at most one operand is allowed, and
/// when present it must be a Toy array.
mlir::LogicalResult ReturnOp::verify() {
  auto numOperands = getNumOperands();
  if (numOperands > 1) {
    std::string msg;
    raw_string_ostream os(msg);
    os << "expects zero or one operand, got " << numOperands;
    return emitOpError(os.str());
  }
  // A bare `toy.return` (no operand) is always valid.
  if (!hasOperand())
    return mlir::success();
  return verifyToySingleOperand(this);
}
/// Build a print operation.
/// `toy.print` produces no result; its single operand is the value to print.
void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state,
                    mlir::Value *value) {
  // Print does not return any value and has a single argument
  state.operands.push_back(value);
}
/// Verifier for the print operation: the printed value must be a Toy array.
mlir::LogicalResult PrintOp::verify() {
  // Forward the LogicalResult of the shared single-operand check directly.
  return verifyToySingleOperand(this);
}
/// Build a transpose operation.
/// The result shape is not inferred at this point, so the result type is the
/// generic (shapeless) Toy array.
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
                        mlir::Value *value) {
  state.operands.push_back(value);
  auto resultType = ToyArrayType::get(builder->getContext());
  state.types.push_back(resultType);
}
/// Verifier for the transpose operation: the input must be a Toy array.
mlir::LogicalResult TransposeOp::verify() {
  // Forward the LogicalResult of the shared single-operand check directly.
  return verifyToySingleOperand(this);
}
/// Build a reshape operation producing `reshapedType` from `value`.
/// Unlike most Toy ops, the result type is supplied explicitly by the caller.
void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state,
                      mlir::Value *value, ToyArrayType reshapedType) {
  state.operands.push_back(value);
  state.types.push_back(reshapedType);
}
/// Verifier for the reshape operation: the input must be a Toy array, and the
/// produced type must be a Toy array with a concrete (non-generic) shape.
mlir::LogicalResult ReshapeOp::verify() {
  if (failed(verifyToySingleOperand(this)))
    return mlir::failure();
  auto resultType = getResult()->getType().dyn_cast<ToyArrayType>();
  if (!resultType)
    return emitOpError("toy.reshape is expected to produce a Toy array");
  // Reshaping to a generic array would leave the shape unknown, defeating the
  // purpose of the operation.
  if (resultType.isGeneric())
    return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
                       "got a generic one.");
  return mlir::success();
}
/// Build an addition operation.
/// The result shape is not inferred at this point, so the result type is the
/// generic (shapeless) Toy array.
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
                  mlir::Value *lhs, mlir::Value *rhs) {
  state.operands.push_back(lhs);
  state.operands.push_back(rhs);
  state.types.push_back(ToyArrayType::get(builder->getContext()));
}
/// Verifier for the addition operation: both operands must be Toy arrays.
mlir::LogicalResult AddOp::verify() {
  // Forward the LogicalResult of the shared binary-operand check directly.
  return verifyToyBinOperands(this);
}
/// Build a multiplication operation.
/// The result shape is not inferred at this point, so the result type is the
/// generic (shapeless) Toy array.
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
                  mlir::Value *lhs, mlir::Value *rhs) {
  state.operands.push_back(lhs);
  state.operands.push_back(rhs);
  state.types.push_back(ToyArrayType::get(builder->getContext()));
}
/// Verifier for the multiplication operation: both operands must be Toy
/// arrays.
mlir::LogicalResult MulOp::verify() {
  // Forward the LogicalResult of the shared binary-operand check directly.
  return verifyToyBinOperands(this);
}
} // namespace toy

View File

@@ -125,7 +125,7 @@ void ASTDumper::dump(NumberExprAST *num) {
llvm::errs() << num->getValue() << " " << loc(num) << "\n";
}
/// Helper to print recursively a literal. This handles nested array like:
/// Helper to print recursively a literal. This handles nested arrays like:
/// [ [ 1, 2 ], [ 3, 4 ] ]
/// We print out such array with the dimensions spelled out at every level:
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]

View File

@@ -22,12 +22,14 @@
#include "toy/Dialect.h"
#include "toy/MLIRGen.h"
#include "toy/Parser.h"
#include <memory>
#include "mlir/Analysis/Verifier.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/CommandLine.h"
@@ -61,6 +63,8 @@ static cl::opt<enum Action> emitAction(
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
@@ -75,9 +79,18 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
return parser.ParseModule();
}
mlir::LogicalResult optimize(mlir::ModuleOp module) {
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createCanonicalizerPass());
// Apply any generic pass manager command line options and run the pipeline.
applyPassManagerCLOptions(pm);
return pm.run(module);
}
int dumpMLIR() {
// Register our Dialect with MLIR
mlir::registerDialect<ToyDialect>();
mlir::registerDialect<mlir::toy::ToyDialect>();
mlir::MLIRContext context;
mlir::OwningModuleRef module;
@@ -106,6 +119,12 @@ int dumpMLIR() {
}
if (!module)
return 1;
if (EnableOpt) {
if (failed(optimize(*module))) {
llvm::errs() << "Module optimization failed\n";
return 7;
}
}
module->dump();
return 0;
}
@@ -125,6 +144,7 @@ int dumpAST() {
}
int main(int argc, char **argv) {
mlir::registerPassManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
switch (emitAction) {

View File

@@ -1,296 +1,262 @@
# Chapter 3: Defining and Registering a Dialect in MLIR
# Chapter 3: High-level Language-Specific Analysis and Transformation
In the previous chapter, we saw how to emit a custom IR for Toy in MLIR using
opaque operations. In this chapter we will register our Dialect with MLIR to
start making the Toy IR more robust and friendly to use.
Creating a dialect that closely represents the semantics of an input language
enables analyses, transformations and optimizations in MLIR that require high
level language information and are generally performed on the language AST. For
example, `clang` has a fairly
[heavy mechanism](https://clang.llvm.org/doxygen/classclang_1_1TreeTransform.html)
for performing template instantiation in C++.
Dialects in MLIR allow for registering operations and types with an MLIRContext.
They also must reserve a "namespace" to avoid collision with other registered
dialects. These registered operations are no longer opaque to MLIR: for example
we can teach the MLIR verifier to enforce some invariants on the IR.
We divide compiler transformations into two categories: local and global. In
this chapter, we focus on how to leverage the Toy Dialect and its high-level
semantics to perform local pattern-match transformations that would be difficult
in LLVM. For this, we use MLIR's
[Generic DAG Rewriter](../../GenericDAGRewriter.md).
```c++
/// This is the definition of the Toy dialect. A dialect inherits from
/// mlir::Dialect and registers custom operations and types (in its constructor).
/// It can also override general behavior of dialects exposed as virtual
/// methods, for example regarding verification and parsing/printing.
class ToyDialect : public mlir::Dialect {
public:
explicit ToyDialect(mlir::MLIRContext *ctx);
There are two methods that can be used to implement pattern-match
transformations: 1. Imperative, C++ pattern-match and rewrite 2. Declarative,
rule-based pattern-match and rewrite using
[Table-driven Declarative Rewrite Rule](../../DeclarativeRewrites.md) (DRR).
Note that the use of DRR requires that the operations be defined using ODS as
described in [Chapter 2](../Ch-2.md).
/// Parse a type registered to this dialect. Overriding this method is
/// required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
mlir::Type parseType(llvm::StringRef tyData,
mlir::Location loc) const override;
# Optimize Transpose using C++ style pattern-match and rewrite
/// Print a type registered to this dialect. Overriding this method is
/// only required for dialects that have custom types.
/// Technically this is only needed to be able to round-trip to textual IR.
void printType(mlir::Type type, llvm::raw_ostream &os) const override;
};
```
Let's start with a simple pattern and try to eliminate a sequence of two
transpose that cancel out: `transpose(transpose(X)) -> X`. Here is the
corresponding Toy example:
The dialect can now be registered in the global registry:
```c++
mlir::registerDialect<ToyDialect>();
```
Any new `MLIRContext` created from now on will recognize the `toy` prefix when
parsing new types and invoke our `parseType` method. We will see later how to
enable custom operations, but first let's define a custom type to handle Toy
arrays.
# Custom Type Handling
As you may have noticed in the previous chapter, dialect specific types in MLIR
are serialized as strings. In the case of Toy, an example would be
`!toy.array<2, 3>`. MLIR will find the ToyDialect from the `!toy` prefix but it
is up to the dialect itself to translate the content of the string into a proper
type.
First we need to define the class representing our type. In MLIR, types are
references to immutable and uniqued objects owned by the MLIRContext. As such,
our `ToyArrayType` will only be a wrapper around a pointer to an uniqued
instance of `ToyArrayTypeStorage` in the Context and provide the public facade
API to interact with the type.
```c++
class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
detail::ToyArrayTypeStorage> {
public:
/// Returns the dimensions for this Toy array, or an empty range for a generic array.
llvm::ArrayRef<int64_t> getShape();
  /// Predicate to test if this array is generic (shape hasn't been inferred yet).
bool isGeneric() { return getShape().empty(); }
/// Return the rank of this array (0 if it is generic)
int getRank() { return getShape().size(); }
/// Get the unique instance of this Type from the context.
/// A ToyArrayType is only defined by the shape of the array.
static ToyArrayType get(mlir::MLIRContext *context,
llvm::ArrayRef<int64_t> shape = {});
/// Support method to enable LLVM-style RTTI type casting.
static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
};
```
Implementing `getShape()` for example is just about retrieving the pointer to
the uniqued instance and forwarding:
```c++
llvm::ArrayRef<int64_t> ToyArrayType::getShape() {
return getImpl()->getShape();
```Toy(.toy)
def transpose_transpose(x) {
return transpose(transpose(x));
}
```
The calls to `getImpl()` give access to the `ToyArrayTypeStorage` that holds the
information for this type. For details about how the storage of the type works,
we'll refer you to `Ch3/mlir/ToyDialect.cpp`.
Finally, the Toy dialect can register the type with MLIR, and implement some
custom parsing for our types:
```c++
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
// note the `toy` prefix that we reserve here.
: mlir::Dialect("toy", ctx) {
// Register our custom type with MLIR.
addTypes<ToyArrayType>();
}
/// Parse a type registered to this dialect, we expect only Toy arrays.
mlir::Type ToyDialect::parseType(StringRef tyData,
mlir::Location loc) const {
// Sanity check: we only support array or array<...>
if (!tyData.startswith("array")) {
getContext()->emitError(loc, "Invalid Toy type '" + tyData +
"', array expected");
return nullptr;
}
// Drop the "array" prefix from the type name, we expect either an empty
// string or just the shape.
tyData = tyData.drop_front(StringRef("array").size());
// This is the generic array case without shape, early return it.
if (tyData.empty())
return ToyArrayType::get(getContext());
  // Use a regex to parse the shape (for efficiency we should store this regex in
// the dialect itself).
SmallVector<StringRef, 4> matches;
auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
if (!shapeRegex.match(tyData, &matches)) {
getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'");
return nullptr;
}
SmallVector<int64_t, 4> shape;
// Iterate through the captures, skip the first one which is the full string.
for (auto dimStr :
llvm::make_range(std::next(matches.begin()), matches.end())) {
if (dimStr.startswith(","))
continue; // POSIX misses non-capturing groups.
if (dimStr.empty())
continue; // '*' makes it an optional group capture
// Convert the capture to an integer
unsigned long long dim;
if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
getContext()->emitError(loc, Twine("Couldn't parse dimension as integer, matched: ") + dimStr);
return mlir::Type();
}
shape.push_back(dim);
}
// Finally we collected all the dimensions in the shape,
// create the array type.
return ToyArrayType::get(getContext(), shape);
}
```
And we also update our IR generation from the Toy AST to use our new type
instead of an opaque one:
```c++
template <typename T> mlir::Type getType(T shape) {
SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
return ToyArrayType::get(&context, shape64);
}
```
From now on, MLIR knows how to parse types that are wrapped in `!toy<...>` and
these won't be opaque anymore. The first consequence is that bogus IR with
respect to our type won't be loaded anymore:
```bash(.sh)
$ echo 'func @foo() -> !toy<"bla">' | toyc -emit=mlir -x mlir -
loc("<stdin>":1:21): error: Invalid Toy type 'bla', array expected
$ echo 'func @foo() -> !toy<"array<>">' | toyc -emit=mlir -x mlir -
loc("<stdin>":1:21): error: Invalid toy array shape '<>'
$ echo 'func @foo() -> !toy<"array<1, >">' | toyc -emit=mlir -x mlir -
loc("<stdin>":1:21): error: Invalid toy array shape '<1, >'
$ echo 'func @foo() -> !toy<"array<1, 2, 3>">' | toyc -emit=mlir -x mlir -
func @foo() -> !toy<"array<1, 3>">
```
## Defining a C++ Class for an Operation
After defining our custom type, we will register all the operations for the Toy
language. Let's walk through the creation of the `toy.generic_call` operation:
Which corresponds to the following IR:
```MLIR(.mlir)
%4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
: (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64>
attributes {toy.generic} {
%0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
%1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64>
"toy.return"(%1) : (tensor<*xf64>) -> ()
}
```
This operation takes a variable number of operands, all of which are expected to
be Toy arrays, and returns a single result. An operation inherits from `mlir::Op`
and adds some optional *traits* to customize its behavior.
This is a good example of a transformation that is trivial to match on the Toy
IR but that would be quite hard for LLVM to figure out. For example, today clang
can't optimize away the temporary array and the computation with the naive
transpose expressed with these loops:
```c++
class GenericCallOp
: public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::OneResult> {
#define N 100
#define M 100
public:
/// MLIR will use this to register the operation with the parser/printer.
static llvm::StringRef getOperationName() { return "toy.generic_call"; }
void sink(void *);
void double_transpose(int A[N][M]) {
int B[M][N];
for(int i = 0; i < N; ++i) {
for(int j = 0; j < M; ++j) {
B[j][i] = A[i][j];
}
}
for(int i = 0; i < N; ++i) {
for(int j = 0; j < M; ++j) {
A[i][j] = B[j][i];
}
}
sink(A);
}
```
/// Operations can add custom verification beyond the traits they define.
/// We will ensure that all the operands are Toy arrays.
bool verify();
For a simple C++ approach to rewrite involving matching a tree-like pattern in
the IR and replacing it with a different set of operations, we can plug into the
MLIR `Canonicalizer` pass by implementing a `RewritePattern`:
/// Interface to the builder to allow:
/// mlir::OpBuilder::create<GenericCallOp>(...)
/// This method populate the `state` that MLIR use to create operations.
/// The `toy.generic_call` operation accepts a callee name and a list of
/// arguments for the call.
static void build(mlir::OpBuilder &builder, mlir::OperationState &state,
llvm::StringRef callee,
llvm::ArrayRef<mlir::Value *> arguments);
```c++
/// Fold transpose(transpose(x)) -> x
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
/// them in order of profitability.
SimplifyRedundantTranspose(mlir::MLIRContext *context)
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// Return the name of the callee by fetching it from the attribute.
llvm::StringRef getCalleeName();
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
// Look through the input of the current transpose.
mlir::Value *transposeInput = op.getOperand();
TransposeOp transposeInputOp =
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
// If the input is defined by another Transpose, bingo!
if (!transposeInputOp)
return matchFailure();
private:
using Op::Op;
// Use the rewriter to perform the replacement
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
return matchSuccess();
}
};
```
and we register this operation in the `ToyDialect` constructor:
The implementation of this rewriter is in `ToyCombine.cpp`. The
[canonicalization pass](../../Canonicalization.md) applies transformations
defined by operations in a greedy, iterative manner. To ensure that the
canonicalization pass applies our new transform, we set
[hasCanonicalizer = 1](../../OpDefinitions.md#hascanonicalizer) and register the
pattern with the canonicalization framework.
```c++
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
addOperations<GenericCallOp>();
addTypes<ToyArrayType>();
// Register our patterns for rewrite by the Canonicalization framework.
void TransposeOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SimplifyRedundantTranspose>(context);
}
```
After creating classes for each of our operations, our dialect is ready and we
have now better invariants enforced in our IR, and nicer API to implement
analyses and transformations in the [next chapter](Ch-4.md).
## Using TableGen
FIXME: complete
## Revisiting the Builder API
We can now update `MLIRGen.cpp`, previously our use of the builder was very
generic and creating a call operation looked like:
```
// Calls to user-defined function are mapped to a custom call that takes
// the callee name as an attribute.
mlir::OperationState result(&context, location, "toy.generic_call");
result.types.push_back(getType(VarType{}));
result.operands = std::move(operands);
for (auto &expr : call.getArgs()) {
auto *arg = mlirGen(*expr);
if (!arg)
return nullptr;
result.operands.push_back(arg);
}
auto calleeAttr = builder->getStringAttr(call.getCallee());
result.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
return builder->createOperation(result)->getResult(0);
```
We replace it with this new version:
We also need to update our main file, `toyc.cpp`, to add an optimization
pipeline. In MLIR, the optimizations are run through a `PassManager` in a
similar way to LLVM:
```c++
for (auto &expr : call.getArgs()) {
auto *arg = mlirGen(*expr);
if (!arg)
return nullptr;
operands.push_back(arg);
}
return builder->create<GenericCallOp>(location, call.getCallee(), operands)->getResult();
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createCanonicalizerPass());
```
This interface offers better type safety, with some invariant enforced at the
API level. For instance the `GenericCallOp` exposes now a `getResult()` method
that does not take any argument, while before MLIR assumed the general cases and
left open the possibility to have multiple returned values. The API was
`getResult(int resultNum)`.
Finally, we can try to run `toyc test/transpose_transpose.toy -emit=mlir -opt`
and observe our pattern in action:
# Putting It All Together
After writing a class for each of our operation and implementing custom
verifier, we try again the same example of invalid IR from the previous chapter:
```bash(.sh)
$ cat test/invalid.mlir
func @main() {
%0 = "toy.print"() : () -> !toy.array<2, 3>
```MLIR(.mlir)
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64>
attributes {toy.generic} {
%0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
%1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64>
"toy.return"(%1) : (tensor<*xf64>) -> ()
}
$ toyc test/invalid.mlir -emit=mlir
loc("test/invalid.mlir":2:8): error: 'toy.print' op requires a single operand
```
This time the IR is correctly rejected by the verifier!
As expected we now directly return the function argument, bypassing any
transpose operation. However, one of the transposes hasn't been eliminated. That
is not ideal! What happened is that our pattern replaced the last transform with
the function input and left behind the now dead transpose input. The
Canonicalizer knows to cleanup dead operations, however MLIR conservatively
assumes that operations may have side-effects. We can fix it by adding a new
trait, `NoSideEffect`, to our `TransposeOp`:
In the [next chapter](Ch-4.md) we will leverage our new dialect to implement
some high-level language-specific analyses and transformations for the Toy
language.
```TableGen(.td):
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {...}
```
Let's retry now `toyc test/transpose_transpose.toy -emit=mlir -opt`:
```MLIR(.mlir)
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64>
attributes {toy.generic} {
"toy.return"(%arg0) : (tensor<*xf64>) -> ()
}
```
Perfect! No `transpose` operation is left, the code is optimal.
In the next section, we use DRR for pattern match optimizations associated with
the Reshape op.
# Optimize Reshapes using DRR
Declarative, rule-based pattern-match and rewrite or DRR is an operation
DAG-based declarative rewriter that provides a table-based syntax for
pattern-match and rewrite rules:
```TableGen(.td):
class Pattern<
dag sourcePattern, list<dag> resultPatterns,
list<dag> additionalConstraints = [],
dag benefitsAdded = (addBenefit 0)>;
```
A redundant reshape optimization similar to SimplifyRedundantTranspose can be
expressed more simply using DRR as follows:
```TableGen(.td):
// Reshape(Reshape(x)) = x
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
(ReshapeOp $arg)>;
```
The automatically generated C++ code corresponding to each of the DRR patterns
can be found under $BUILD_DIR/projects/mlir/examples/toy/Ch3/ToyCombine.inc.
DRR also provides a method for adding argument constraints when the
transformation is conditional on some properties of the arguments and results.
An example is a transformation that eliminates reshapes when they are redundant,
i.e. when the input and output shapes are identical.
```TableGen(.td):
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
def RedundantReshapeOptPattern : Pat<
(ReshapeOp:$res $arg), (replaceWithValue $arg),
[(TypesAreIdentical $res, $arg)]>;
```
Some optimizations may require additional transformations on instruction
arguments. This is achieved using NativeCodeCall, which allows for more complex
transformations either by calling into a C++ helper function or by using inline
C++. An example of such an optimization is FoldConstantReshape, where we
optimize Reshape of a constant value by reshaping the constant in place and
eliminating the reshape operation.
```TableGen(.td):
def ReshapeConstant : NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
def FoldConstantReshapeOptPattern : Pat<
(ReshapeOp:$res (ConstantOp $arg)),
(ConstantOp (ReshapeConstant $arg, $res))>;
```
We demonstrate these reshape optimizations using the following
trivialReshape.toy program:
```c++
def main() {
var a<2,1> = [1, 2];
var b<2,1> = a;
var c<2,1> = b;
print(c);
}
```
```MLIR(.mlir)
module {
func @main() {
%0 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>}
: () -> tensor<2xf64>
%1 = "toy.reshape"(%0) : (tensor<2xf64>) -> tensor<2x1xf64>
%2 = "toy.reshape"(%1) : (tensor<2x1xf64>) -> tensor<2x1xf64>
%3 = "toy.reshape"(%2) : (tensor<2x1xf64>) -> tensor<2x1xf64>
"toy.print"(%3) : (tensor<2x1xf64>) -> ()
"toy.return"() : () -> ()
}
}
```
We can try to run `toyc test/trivialReshape.toy -emit=mlir -opt` and observe our
pattern in action:
```MLIR(.mlir)
module {
func @main() {
%0 = "toy.constant"() {value = dense<[[1.000000e+00], [2.000000e+00]]> \
: tensor<2x1xf64>} : () -> tensor<2x1xf64>
"toy.print"(%0) : (tensor<2x1xf64>) -> ()
"toy.return"() : () -> ()
}
}
```
As expected, no reshape operations remain after canonicalization.
Further details on the declarative rewrite method can be found at
[Table-driven Declarative Rewrite Rule (DRR)](../../DeclarativeRewrites.md).

View File

@@ -13,20 +13,19 @@ def main() {
print(d);
}
# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
# CHECK-NEXT: attributes {toy.generic = true} {
# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
# CHECK-NEXT: "toy.return"(%1) : (!toy.array) -> ()
# CHECK-NEXT: }
# CHECK-LABEL: func @main() {
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> !toy.array<2, 3>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3>
# CHECK-NEXT: %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> !toy.array<6>
# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3>
# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
# CHECK-NEXT: "toy.print"(%5) : (!toy.array) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()
# CHECK-LABEL: func @multiply_transpose(
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>)
# CHECK-NEXT: attributes {toy.generic} {
# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: [[VAL_3:%.*]] = "toy.mul"([[VAL_0]], [[VAL_2]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
# CHECK-NEXT: "toy.return"([[VAL_3]]) : (tensor<*xf64>) -> ()
# CHECK-LABEL: func @main() {
# CHECK-NEXT: [[VAL_4:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_5:%.*]] = "toy.reshape"([[VAL_4]]) : (tensor<2x3xf64>) -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_6:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
# CHECK-NEXT: [[VAL_7:%.*]] = "toy.reshape"([[VAL_6]]) : (tensor<6xf64>) -> tensor<2x3xf64>
# CHECK-NEXT: [[VAL_8:%.*]] = "toy.generic_call"([[VAL_5]], [[VAL_7]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_7]], [[VAL_5]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
# CHECK-NEXT: "toy.print"([[VAL_9]]) : (tensor<*xf64>) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()

View File

@@ -1,11 +1,9 @@
// RUN: not toyc-ch3 %s -emit=mlir 2>&1
// This IR is not "valid":
// The following IR is not "valid":
// - toy.print should not return a value.
// - toy.print should take an argument.
// - There should be a block terminator.
// This all round-trip since this is opaque for MLIR.
func @main() {
%0 = "toy.print"() : () -> !toy.array<2, 3>
%0 = "toy.print"() : () -> tensor<2x3xf64>
}

View File

@@ -6,9 +6,9 @@ def main() {
}
# CHECK-LABEL: func @main() {
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<1xf64>} : () -> !toy.array<1>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2>
# CHECK-NEXT: "toy.print"(%1) : (!toy.array<2, 2>) -> ()
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<f64>} : () -> tensor<f64>
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor<f64>) -> tensor<2x2xf64>
# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> ()
# CHECK-NEXT: "toy.return"() : () -> ()
# CHECK-NEXT: }