mirror of
https://github.com/intel/llvm.git
synced 2026-02-02 18:18:09 +08:00
Update Chapter 3 to demonstrate pattern match and rewrite optimizations
This is using Table-driven Declarative Rewrite Rules (DRR), the previous version of the tutorial only showed the C++ patterns. Closes tensorflow/mlir#187 PiperOrigin-RevId: 274852321
This commit is contained in:
committed by
A. Unique TensorFlower
parent
4e85dafedd
commit
cd45b0c8d9
@@ -2,16 +2,33 @@ set(LLVM_LINK_COMPONENTS
|
||||
Support
|
||||
)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS include/toy/Ops.td)
|
||||
mlir_tablegen(include/toy/Ops.h.inc -gen-op-decls)
|
||||
mlir_tablegen(include/toy/Ops.cpp.inc -gen-op-defs)
|
||||
add_public_tablegen_target(ToyCh3OpsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td)
|
||||
mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include")
|
||||
add_public_tablegen_target(ToyCh3CombineIncGen)
|
||||
|
||||
add_toy_chapter(toyc-ch3
|
||||
toyc.cpp
|
||||
parser/AST.cpp
|
||||
mlir/MLIRGen.cpp
|
||||
mlir/ToyDialect.cpp
|
||||
mlir/Dialect.cpp
|
||||
mlir/ToyCombine.cpp
|
||||
)
|
||||
|
||||
add_dependencies(toyc-ch3 ToyCh3OpsIncGen)
|
||||
add_dependencies(toyc-ch3 ToyCh3CombineIncGen)
|
||||
include_directories(include/)
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
|
||||
target_link_libraries(toyc-ch3
|
||||
PRIVATE
|
||||
MLIRAnalysis
|
||||
MLIRIR
|
||||
MLIRParser
|
||||
MLIRPass
|
||||
MLIRTransforms)
|
||||
|
||||
|
||||
@@ -33,10 +33,9 @@
|
||||
|
||||
namespace toy {
|
||||
|
||||
/// A variable
|
||||
/// A variable type with shape information.
|
||||
struct VarType {
|
||||
enum { TY_FLOAT, TY_INT } elt_ty;
|
||||
std::vector<int> shape;
|
||||
std::vector<int64_t> shape;
|
||||
};
|
||||
|
||||
/// Base class for all expression nodes.
|
||||
@@ -50,9 +49,7 @@ public:
|
||||
Expr_Var,
|
||||
Expr_BinOp,
|
||||
Expr_Call,
|
||||
Expr_Print, // builtin
|
||||
Expr_If,
|
||||
Expr_For,
|
||||
Expr_Print,
|
||||
};
|
||||
|
||||
ExprAST(ExprASTKind kind, Location location)
|
||||
@@ -85,7 +82,7 @@ public:
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; }
|
||||
};
|
||||
|
||||
///
|
||||
/// Expression class for a literal value.
|
||||
class LiteralExprAST : public ExprAST {
|
||||
std::vector<std::unique_ptr<ExprAST>> values;
|
||||
std::vector<int64_t> dims;
|
||||
@@ -116,7 +113,7 @@ public:
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; }
|
||||
};
|
||||
|
||||
///
|
||||
/// Expression class for defining a variable.
|
||||
class VarDeclExprAST : public ExprAST {
|
||||
std::string name;
|
||||
VarType type;
|
||||
@@ -136,7 +133,7 @@ public:
|
||||
static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; }
|
||||
};
|
||||
|
||||
///
|
||||
/// Expression class for a return operator.
|
||||
class ReturnExprAST : public ExprAST {
|
||||
llvm::Optional<std::unique_ptr<ExprAST>> expr;
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the IR Dialect for the Toy language.
|
||||
// See g3doc/Tutorials/Toy/Ch-3.md for more information.
|
||||
// See g3doc/Tutorials/Toy/Ch-2.md for more information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -25,311 +25,29 @@
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
class Builder;
|
||||
}
|
||||
|
||||
namespace toy {
|
||||
|
||||
/// This is the definition of the Toy dialect. A dialect inherits from
|
||||
/// mlir::Dialect and register custom operations and types (in its constructor).
|
||||
/// It can also overriding general behavior of dialects exposed as virtual
|
||||
/// method, for example regarding verification and parsing/printing.
|
||||
/// mlir::Dialect and registers custom attributes, operations, and types (in its
|
||||
/// constructor). It can also override some general behavior exposed via virtual
|
||||
/// methods.
|
||||
class ToyDialect : public mlir::Dialect {
|
||||
public:
|
||||
explicit ToyDialect(mlir::MLIRContext *ctx);
|
||||
|
||||
/// Parse a type registered to this dialect. Overriding this method is
|
||||
/// required for dialects that have custom types.
|
||||
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||
mlir::Type parseType(llvm::StringRef tyData,
|
||||
mlir::Location loc) const override;
|
||||
|
||||
/// Print a type registered to this dialect. Overriding this method is
|
||||
/// only required for dialects that have custom types.
|
||||
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||
void printType(mlir::Type type, llvm::raw_ostream &os) const override;
|
||||
/// Provide a utility accessor to the dialect namespace. This is used by
|
||||
/// several utilities for casting between dialects.
|
||||
static llvm::StringRef getDialectNamespace() { return "toy"; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
/////////////////////// Custom Types for the Dialect ///////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace detail {
|
||||
struct ToyArrayTypeStorage;
|
||||
}
|
||||
|
||||
/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa.
|
||||
enum ToyTypeKind {
|
||||
// The enum starts at the range reserved for this dialect.
|
||||
TOY_TYPE = mlir::Type::FIRST_TOY_TYPE,
|
||||
TOY_ARRAY,
|
||||
};
|
||||
|
||||
/// Type for Toy arrays.
|
||||
/// In MLIR Types are reference to immutable and uniqued objects owned by the
|
||||
/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued
|
||||
/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and
|
||||
/// provides the public facade API to interact with the type.
|
||||
class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
|
||||
detail::ToyArrayTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Returns the dimensions for this array, or and empty range for a generic
|
||||
/// array.
|
||||
llvm::ArrayRef<int64_t> getShape();
|
||||
|
||||
/// Predicate to test if this array is generic (shape haven't been inferred
|
||||
/// yet).
|
||||
bool isGeneric() { return getShape().empty(); }
|
||||
|
||||
/// Return the rank of this array (0 if it is generic).
|
||||
int getRank() { return getShape().size(); }
|
||||
|
||||
/// Return the type of individual elements in the array.
|
||||
mlir::Type getElementType();
|
||||
|
||||
/// Get the unique instance of this Type from the context.
|
||||
/// A ToyArrayType is only defined by the shape of the array.
|
||||
static ToyArrayType get(mlir::MLIRContext *context,
|
||||
llvm::ArrayRef<int64_t> shape = {});
|
||||
|
||||
/// Support method to enable LLVM-style RTTI type casting.
|
||||
static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////// Custom Operations for the Dialect /////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Constant operation turns a literal into an SSA value. The data is attached
|
||||
/// to the operation as an attribute. For example:
|
||||
///
|
||||
/// %0 = "toy.constant"()
|
||||
/// {value: dense<tensor<2x3xf64>, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>}
|
||||
/// : () -> !toy.array<2, 3>
|
||||
///
|
||||
/// An operation inherits from `class Op` and specifies optional traits. Here we
|
||||
/// indicate that `toy.constant` does not have any operands and returns a single
|
||||
/// result. The traits provide some utilities methods for the operation, for
|
||||
/// instance we will be able to use `getResult()`, but `getOperand()` won't be
|
||||
/// available.
|
||||
class ConstantOp : public mlir::Op<ConstantOp, mlir::OpTrait::ZeroOperands,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
/// This is the name used by MLIR to match an operation to this class during
|
||||
/// parsing.
|
||||
static llvm::StringRef getOperationName() { return "toy.constant"; }
|
||||
|
||||
/// The operation can have extra verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populates the `state` that MLIR uses to create operations.
|
||||
/// The `toy.constant` operation does not have arguments but attaches a
|
||||
/// constant array as an attribute and returns it as an SSA value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
mlir::DenseElementsAttr value);
|
||||
|
||||
/// Similar to the one above, but takes a single float and returns a
|
||||
/// !toy.array<1>.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::FloatAttr value);
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Generic calls represent calls to a user defined function that needs to
|
||||
/// be specialized for the shape of its arguments. The callee name is attached
|
||||
/// as a literal string as an attribute. The arguments list must match the
|
||||
/// arguments expected by the callee. For example:
|
||||
///
|
||||
/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
|
||||
/// : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
|
||||
///
|
||||
/// This is only valid if a function named "my_func" exists and takes two
|
||||
/// arguments.
|
||||
class GenericCallOp
|
||||
: public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::OneResult> {
|
||||
public:
|
||||
/// MLIR will use this to register the operation with the parser/printer.
|
||||
static llvm::StringRef getOperationName() { return "toy.generic_call"; }
|
||||
|
||||
/// Operations can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to the builder to allow:
|
||||
/// mlir::Builder::create<GenericCallOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.generic_call` operation accepts a callee name and a list of
|
||||
/// arguments for the call.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
llvm::StringRef callee,
|
||||
llvm::ArrayRef<mlir::Value *> arguments);
|
||||
|
||||
/// Return the name of the callee.
|
||||
llvm::StringRef getCalleeName();
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Return operations terminate blocks (and functions as well). They take a
|
||||
/// single argument and the type must match the function return type.
|
||||
class ReturnOp
|
||||
: public mlir::Op<ReturnOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::ZeroResult, mlir::OpTrait::IsTerminator> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.return"; }
|
||||
|
||||
/// Operations can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.return` operation accepts an optional single array as an argument
|
||||
/// and does not have any returned value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value = nullptr);
|
||||
|
||||
/// Return true if there is a returned value.
|
||||
bool hasOperand() { return 0 != getNumOperands(); }
|
||||
|
||||
/// Helper to return the optional operand. Caller must check if the operand
|
||||
/// is present before calling this.
|
||||
mlir::Value *getOperand() { return getOperation()->getOperand(0); }
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// The print builtin takes a single array argument and does not return any.
|
||||
class PrintOp : public mlir::Op<PrintOp, mlir::OpTrait::OneOperand,
|
||||
mlir::OpTrait::ZeroResult> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.print"; }
|
||||
|
||||
/// Operations can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.print` operation accepts a single array as argument and does
|
||||
/// not have any returned value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value);
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
class TransposeOp : public mlir::Op<TransposeOp, mlir::OpTrait::OneOperand,
|
||||
mlir::OpTrait::OneResult> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.transpose"; }
|
||||
|
||||
/// Operation can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<TransposeOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.transpose` operation accepts a single array as argument and
|
||||
/// returns the transposed array as its only result.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value);
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Reshape operation is transforming its input array into a new array with the
|
||||
/// same number of elements but different shapes. For example:
|
||||
///
|
||||
/// %0 = "toy.reshape"(%arg1) : (!toy.array<10>) -> !toy.array<5, 2>
|
||||
///
|
||||
class ReshapeOp : public mlir::Op<ReshapeOp, mlir::OpTrait::OneOperand,
|
||||
mlir::OpTrait::OneResult> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.reshape"; }
|
||||
|
||||
/// Operation can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<ReshapeOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.reshape` operation accepts a single array as argument and
|
||||
/// returns the array with the specified reshapedType as its only result.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value, ToyArrayType reshapedType);
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Binary operation implementing a multiplication. For two-dimensional array
|
||||
/// a matrix multiplication is implemented, while for one dimensional array a
|
||||
/// dot product is performed.
|
||||
class MulOp : public mlir::Op<MulOp, mlir::OpTrait::NOperands<2>::Impl,
|
||||
mlir::OpTrait::OneResult> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.mul"; }
|
||||
|
||||
/// Operation can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.mul` operation accepts two operands as argument and returns
|
||||
/// a single value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *lhs, mlir::Value *rhs);
|
||||
|
||||
/// Convenience accessor for LHS of the expression.
|
||||
mlir::Value *getLHS() { return getOperand(0); }
|
||||
|
||||
/// Convenience accessor for RHS of the expression.
|
||||
mlir::Value *getRHS() { return getOperand(1); }
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
|
||||
/// Element wise addition of two arrays. The shape must match.
|
||||
class AddOp : public mlir::Op<AddOp, mlir::OpTrait::NOperands<2>::Impl,
|
||||
mlir::OpTrait::OneResult> {
|
||||
public:
|
||||
static llvm::StringRef getOperationName() { return "toy.add"; }
|
||||
|
||||
/// Operation can add custom verification beyond the traits they define.
|
||||
mlir::LogicalResult verify();
|
||||
|
||||
/// Interface to mlir::Builder::create<PrintOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.mul` operation accepts two operands as argument and returns
|
||||
/// a single value.
|
||||
static void build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *lhs, mlir::Value *rhs);
|
||||
|
||||
/// Convenience accessor for LHS of the expression.
|
||||
mlir::Value *getLHS() { return getOperand(0); }
|
||||
|
||||
/// Convenience accessor for RHS of the expression.
|
||||
mlir::Value *getRHS() { return getOperand(1); }
|
||||
|
||||
/// Inherit constructor.
|
||||
using Op::Op;
|
||||
};
|
||||
/// Include the auto-generated header file containing the declarations of the
|
||||
/// toy operations.
|
||||
#define GET_OP_CLASSES
|
||||
#include "toy/Ops.h.inc"
|
||||
|
||||
} // end namespace toy
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TUTORIAL_TOY_DIALECT_H_
|
||||
|
||||
@@ -31,7 +31,7 @@ namespace toy {
|
||||
|
||||
/// Structure definition a location in a file.
|
||||
struct Location {
|
||||
std::shared_ptr<std::string> file; ///< filename
|
||||
std::shared_ptr<std::string> file; ///< filename.
|
||||
int line; ///< line number.
|
||||
int col; ///< column number.
|
||||
};
|
||||
|
||||
247
mlir/examples/toy/Ch3/include/toy/Ops.td
Normal file
247
mlir/examples/toy/Ch3/include/toy/Ops.td
Normal file
@@ -0,0 +1,247 @@
|
||||
//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines the operations of the Toy dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifdef TOY_OPS
|
||||
#else
|
||||
#define TOY_OPS
|
||||
|
||||
#ifdef OP_BASE
|
||||
#else
|
||||
include "mlir/IR/OpBase.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
// Provide a definition of the 'toy' dialect in the ODS framework so that we
|
||||
// can define our operations.
|
||||
def Toy_Dialect : Dialect {
|
||||
let name = "toy";
|
||||
let cppNamespace = "toy";
|
||||
}
|
||||
|
||||
// Base class for toy dialect operations. This operation inherits from the base
|
||||
// `Op` class in OpBase.td, and provides:
|
||||
// * The parent dialect of the operation.
|
||||
// * The mnemonic for the operation, or the name without the dialect prefix.
|
||||
// * A list of traits for the operation.
|
||||
class Toy_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<Toy_Dialect, mnemonic, traits>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Toy Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// We define a toy operation by inherting from our base 'Toy_Op' class above.
|
||||
// Here we provide the mnemonic and a list of traits for the operation. The
|
||||
// constant operation is marked as 'NoSideEffect' as it is a pure operation
|
||||
// and may be removed if dead.
|
||||
def ConstantOp : Toy_Op<"constant", [NoSideEffect]> {
|
||||
// Provide a summary and description for this operation. This can be used to
|
||||
// auto-generate documenatation of the operations within our dialect.
|
||||
let summary = "constant";
|
||||
let description = [{
|
||||
Constant operation turns a literal into an SSA value. The data is attached
|
||||
to the operation as an attribute. For example:
|
||||
|
||||
```mlir
|
||||
%0 = "toy.constant"()
|
||||
{ value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
|
||||
: () -> tensor<2x3xf64>
|
||||
```
|
||||
}];
|
||||
|
||||
// The constant operation takes an attribute as the only input.
|
||||
let arguments = (ins F64ElementsAttr:$value);
|
||||
|
||||
// The constant operation returns a single value of TensorType.
|
||||
let results = (outs F64Tensor);
|
||||
|
||||
// Add custom build methods for the constant operation. These method populates
|
||||
// the `state` that MLIR uses to create operations, i.e. these are used when
|
||||
// using `builder.create<ConstantOp>(...)`.
|
||||
let builders = [
|
||||
// Build a constant with a given constant tensor value.
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"DenseElementsAttr value", [{
|
||||
build(builder, result, value.getType(), value);
|
||||
}]>,
|
||||
|
||||
// Build a constant with a given constant floating-point value.
|
||||
OpBuilder<"Builder *builder, OperationState &result, double value", [{
|
||||
buildConstantOp(builder, result, value);
|
||||
}]>
|
||||
];
|
||||
|
||||
// Invoke a static verify method to verify this constant operation.
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def AddOp : Toy_Op<"add", [NoSideEffect]> {
|
||||
let summary = "element-wise addition operation";
|
||||
let description = [{
|
||||
The "add" operation performs element-wise addition between two tensors.
|
||||
The shapes of the tensor operands are expected to match.
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
|
||||
let results = (outs F64Tensor);
|
||||
|
||||
// Allow building an AddOp with from the two input operands.
|
||||
let builders = [
|
||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
||||
buildAddOp(b, result, lhs, rhs);
|
||||
}]
|
||||
>];
|
||||
}
|
||||
|
||||
def GenericCallOp : Toy_Op<"generic_call"> {
|
||||
let summary = "generic call operation";
|
||||
let description = [{
|
||||
Generic calls represent calls to a user defined function that needs to
|
||||
be specialized for the shape of its arguments. The callee name is attached
|
||||
as a symbol reference via an attribute. The arguments list must match the
|
||||
arguments expected by the callee. For example:
|
||||
|
||||
```mlir
|
||||
%4 = "toy.generic_call"(%1, %3) {callee = @my_func}
|
||||
: (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||
```
|
||||
|
||||
This is only valid if a function named "my_func" exists and takes two
|
||||
arguments.
|
||||
}];
|
||||
|
||||
// The generic call operation takes a symbol reference attribute as the
|
||||
// callee, and inputs for the call.
|
||||
let arguments = (ins SymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);
|
||||
|
||||
// The generic call operation returns a single value of TensorType.
|
||||
let results = (outs F64Tensor);
|
||||
|
||||
// Add custom build methods for the generic call operation.
|
||||
let builders = [
|
||||
// Build a constant with a given constant tensor value.
|
||||
OpBuilder<"Builder *builder, OperationState &result, "
|
||||
"StringRef callee, ArrayRef<Value *> arguments", [{
|
||||
buildGenericCallOp(builder, result, callee, arguments);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
def MulOp : Toy_Op<"mul", [NoSideEffect]> {
|
||||
let summary = "element-wise multiplication operation";
|
||||
let description = [{
|
||||
The "mul" operation performs element-wise multiplication between two
|
||||
tensors. The shapes of the tensor operands are expected to match.
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs);
|
||||
let results = (outs F64Tensor);
|
||||
|
||||
// Allow building a MulOp with from the two input operands.
|
||||
let builders = [
|
||||
OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{
|
||||
buildMulOp(b, result, lhs, rhs);
|
||||
}]
|
||||
>];
|
||||
}
|
||||
|
||||
def PrintOp : Toy_Op<"print"> {
|
||||
let summary = "print operation";
|
||||
let description = [{
|
||||
The "print" builtin operation prints a given input tensor, and produces
|
||||
no results.
|
||||
}];
|
||||
|
||||
// The print operation takes an input tensor to print.
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
}
|
||||
|
||||
def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> {
|
||||
let summary = "tensor reshape operation";
|
||||
let description = [{
|
||||
Reshape operation is transforming its input tensor into a new tensor with
|
||||
the same number of elements but different shapes. For example:
|
||||
|
||||
```mlir
|
||||
%0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
|
||||
// Enabled registering canonicalization patterns with this operation.
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
// We expect that the reshape operation returns a statically shaped tensor.
|
||||
let results = (outs StaticShapeTensorOf<[F64]>);
|
||||
}
|
||||
|
||||
def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> {
|
||||
let summary = "return operation";
|
||||
let description = [{
|
||||
The "return" operation represents a return operation within a function.
|
||||
The operation takes an optional tensor operand and produces no results.
|
||||
The operand type must match the signature of the function that contains
|
||||
the operation. For example:
|
||||
|
||||
```mlir
|
||||
func @foo() -> tensor<2xf64> {
|
||||
...
|
||||
toy.return %0 : tensor<2xf64>
|
||||
}
|
||||
```
|
||||
}];
|
||||
|
||||
// The return operation takes an optional input operand to return. This
|
||||
// value must match the return type of the enclosing function.
|
||||
let arguments = (ins Variadic<F64Tensor>:$input);
|
||||
|
||||
// Allow building a ReturnOp with no return operand.
|
||||
let builders = [OpBuilder<
|
||||
"Builder *b, OperationState &result", [{ build(b, result, llvm::None); }]
|
||||
>];
|
||||
|
||||
// Provide extra utility definitions on the c++ operation class definition.
|
||||
let extraClassDeclaration = [{
|
||||
bool hasOperand() { return getNumOperands() != 0; }
|
||||
}];
|
||||
|
||||
// Invoke a static verify method to verify this return operation.
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {
|
||||
let summary = "transpose operation";
|
||||
|
||||
let arguments = (ins F64Tensor:$input);
|
||||
let results = (outs F64Tensor);
|
||||
|
||||
// Enabled registering canonicalization patterns with this operation.
|
||||
let hasCanonicalizer = 1;
|
||||
|
||||
// Allow building a TransposeOp with from the two input operands.
|
||||
let builders = [
|
||||
OpBuilder<"Builder *b, OperationState &result, Value *input", [{
|
||||
buildTransposeOp(b, result, input);
|
||||
}]
|
||||
>];
|
||||
}
|
||||
|
||||
#endif // TOY_OPS
|
||||
151
mlir/examples/toy/Ch3/mlir/Dialect.cpp
Normal file
151
mlir/examples/toy/Ch3/mlir/Dialect.cpp
Normal file
@@ -0,0 +1,151 @@
|
||||
//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the dialect for the Toy IR: custom type parsing and
|
||||
// operation verification.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::toy;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ToyDialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Dialect creation, the instance will be owned by the context. This is the
|
||||
/// point of registration of custom types and operations for the dialect.
|
||||
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "toy/Ops.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Toy Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Build a constant operation.
|
||||
/// The builder is passed as an argument, so is the state that this method is
|
||||
/// expected to fill in order to build the operation.
|
||||
static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||
double value) {
|
||||
auto dataType = builder->getTensorType({}, builder->getF64Type());
|
||||
auto dataAttribute = DenseElementsAttr::get(dataType, value);
|
||||
ConstantOp::build(builder, state, dataType, dataAttribute);
|
||||
}
|
||||
|
||||
/// Verifier for constant operation.
|
||||
static mlir::LogicalResult verify(ConstantOp op) {
|
||||
// If the return type of the constant is not an unranked tensor, the shape
|
||||
// must match the shape of the attribute holding the data.
|
||||
auto resultType = op.getResult()->getType().cast<RankedTensorType>();
|
||||
if (!resultType)
|
||||
return success();
|
||||
|
||||
auto attrType = op.value().getType().cast<mlir::TensorType>();
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return op.emitOpError(
|
||||
"return type must match the one of the attached value "
|
||||
"attribute: ")
|
||||
<< attrType.getRank() << " != " << resultType.getRank();
|
||||
}
|
||||
for (int dim = 0; dim < attrType.getRank(); ++dim) {
|
||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||
return op.emitOpError(
|
||||
"return type shape mismatches its attribute at dimension ")
|
||||
<< dim << ": " << attrType.getShape()[dim]
|
||||
<< " != " << resultType.getShape()[dim];
|
||||
}
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *lhs, mlir::Value *rhs) {
|
||||
state.addTypes(builder->getTensorType(builder->getF64Type()));
|
||||
state.addOperands({lhs, rhs});
|
||||
}
|
||||
|
||||
static void buildGenericCallOp(mlir::Builder *builder,
|
||||
mlir::OperationState &state, StringRef callee,
|
||||
ArrayRef<mlir::Value *> arguments) {
|
||||
// Generic call always returns an unranked Tensor initially.
|
||||
state.addTypes(builder->getTensorType(builder->getF64Type()));
|
||||
state.addOperands(arguments);
|
||||
state.addAttribute("callee", builder->getSymbolRefAttr(callee));
|
||||
}
|
||||
|
||||
static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *lhs, mlir::Value *rhs) {
|
||||
state.addTypes(builder->getTensorType(builder->getF64Type()));
|
||||
state.addOperands({lhs, rhs});
|
||||
}
|
||||
|
||||
static mlir::LogicalResult verify(ReturnOp op) {
|
||||
// We know that the parent operation is a function, because of the 'HasParent'
|
||||
// trait attached to the operation definition.
|
||||
auto function = cast<FuncOp>(op.getParentOp());
|
||||
|
||||
/// ReturnOps can only have a single optional operand.
|
||||
if (op.getNumOperands() > 1)
|
||||
return op.emitOpError() << "expects at most 1 return operand";
|
||||
|
||||
// The operand number and types must match the function signature.
|
||||
const auto &results = function.getType().getResults();
|
||||
if (op.getNumOperands() != results.size())
|
||||
return op.emitOpError()
|
||||
<< "does not return the same number of values ("
|
||||
<< op.getNumOperands() << ") as the enclosing function ("
|
||||
<< results.size() << ")";
|
||||
|
||||
// If the operation does not have an input, we are done.
|
||||
if (!op.hasOperand())
|
||||
return mlir::success();
|
||||
|
||||
auto inputType = *op.operand_type_begin();
|
||||
auto resultType = results.front();
|
||||
|
||||
// Check that the result type of the function matches the operand type.
|
||||
if (inputType == resultType || inputType.isa<mlir::UnrankedTensorType>() ||
|
||||
resultType.isa<mlir::UnrankedTensorType>())
|
||||
return mlir::success();
|
||||
|
||||
return op.emitError() << "type of return operand ("
|
||||
<< *op.operand_type_begin()
|
||||
<< ") doesn't match function result type ("
|
||||
<< results.front() << ")";
|
||||
}
|
||||
|
||||
static void buildTransposeOp(mlir::Builder *builder,
|
||||
mlir::OperationState &state, mlir::Value *value) {
|
||||
state.addTypes(builder->getTensorType(builder->getF64Type()));
|
||||
state.addOperands(value);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "toy/Ops.cpp.inc"
|
||||
@@ -25,30 +25,30 @@
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/Dialect/StandardOps/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/ScopedHashTable.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir::toy;
|
||||
using namespace toy;
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::cast;
|
||||
using llvm::dyn_cast;
|
||||
using llvm::isa;
|
||||
using llvm::makeArrayRef;
|
||||
using llvm::ScopedHashTableScope;
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
using std::make_unique;
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -57,56 +57,43 @@ namespace {
|
||||
/// This will emit operations that are specific to the Toy language, preserving
|
||||
/// the semantics of the language and (hopefully) allow to perform accurate
|
||||
/// analysis and transformation based on these high level semantics.
|
||||
///
|
||||
/// At this point we take advantage of the "raw" MLIR APIs to create operations
|
||||
/// that haven't been registered in any way with MLIR. These operations are
|
||||
/// unknown to MLIR, custom passes could operate by string-matching the name of
|
||||
/// these operations, but no other type checking or semantic is associated with
|
||||
/// them natively by MLIR.
|
||||
class MLIRGenImpl {
|
||||
public:
|
||||
MLIRGenImpl(mlir::MLIRContext &context) : context(context) {}
|
||||
MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {}
|
||||
|
||||
/// Public API: convert the AST for a Toy module (source file) to an MLIR
|
||||
/// Module.
|
||||
mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
|
||||
/// Module operation.
|
||||
mlir::ModuleOp mlirGen(ModuleAST &moduleAST) {
|
||||
// We create an empty MLIR module and codegen functions one at a time and
|
||||
// add them to the module.
|
||||
theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
theModule = mlir::ModuleOp::create(builder.getUnknownLoc());
|
||||
|
||||
for (FunctionAST &F : moduleAST) {
|
||||
auto func = mlirGen(F);
|
||||
if (!func)
|
||||
return nullptr;
|
||||
theModule->push_back(func);
|
||||
theModule.push_back(func);
|
||||
}
|
||||
|
||||
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
|
||||
// this won't do much, but it should at least check some structural
|
||||
// properties.
|
||||
if (failed(mlir::verify(*theModule))) {
|
||||
emitError(mlir::UnknownLoc::get(&context), "module verification error");
|
||||
// Verify the module after we have finished constructing it, this will check
|
||||
// the structural properties of the IR and invoke any specific verifiers we
|
||||
// have on the Toy operations.
|
||||
if (failed(mlir::verify(theModule))) {
|
||||
theModule.emitError("module verification error");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return std::move(theModule);
|
||||
return theModule;
|
||||
}
|
||||
|
||||
private:
|
||||
/// In MLIR (like in LLVM) a "context" object holds the memory allocation and
|
||||
/// the ownership of many internal structure of the IR and provide a level
|
||||
/// of "uniquing" across multiple modules (types for instance).
|
||||
mlir::MLIRContext &context;
|
||||
/// A "module" matches a Toy source file: containing a list of functions.
|
||||
mlir::ModuleOp theModule;
|
||||
|
||||
/// A "module" matches a source file: it contains a list of functions.
|
||||
mlir::OwningModuleRef theModule;
|
||||
|
||||
/// The builder is a helper class to create IR inside a function. It is
|
||||
/// re-initialized every time we enter a function and kept around as a
|
||||
/// convenience for emitting individual operations.
|
||||
/// The builder is stateful, in particular it keeeps an "insertion point":
|
||||
/// this is where the next operations will be introduced.
|
||||
std::unique_ptr<mlir::OpBuilder> builder;
|
||||
/// The builder is a helper class to create IR inside a function. The builder
|
||||
/// is stateful, in particular it keeeps an "insertion point": this is where
|
||||
/// the next operations will be introduced.
|
||||
mlir::OpBuilder builder;
|
||||
|
||||
/// The symbol table maps a variable name to a value in the current scope.
|
||||
/// Entering a function creates a new scope, and the function arguments are
|
||||
@@ -116,37 +103,35 @@ private:
|
||||
|
||||
/// Helper conversion for a Toy AST location to an MLIR location.
|
||||
mlir::Location loc(Location loc) {
|
||||
return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context),
|
||||
loc.line, loc.col, &context);
|
||||
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
|
||||
loc.col);
|
||||
}
|
||||
|
||||
/// Declare a variable in the current scope, return true if the variable
|
||||
/// Declare a variable in the current scope, return success if the variable
|
||||
/// wasn't declared yet.
|
||||
bool declare(llvm::StringRef var, mlir::Value *value) {
|
||||
if (symbolTable.count(var)) {
|
||||
return false;
|
||||
}
|
||||
mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) {
|
||||
if (symbolTable.count(var))
|
||||
return mlir::failure();
|
||||
symbolTable.insert(var, value);
|
||||
return true;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Create the prototype for an MLIR function with as many arguments as the
|
||||
/// provided Toy AST prototype.
|
||||
mlir::FuncOp mlirGen(PrototypeAST &proto) {
|
||||
auto location = loc(proto.loc());
|
||||
|
||||
// This is a generic function, the return type will be inferred later.
|
||||
llvm::SmallVector<mlir::Type, 4> ret_types;
|
||||
// Arguments type is uniformly a generic array.
|
||||
// Arguments type are uniformly unranked tensors.
|
||||
llvm::SmallVector<mlir::Type, 4> arg_types(proto.getArgs().size(),
|
||||
getType(VarType{}));
|
||||
auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context);
|
||||
auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(),
|
||||
func_type, /* attrs = */ {});
|
||||
auto func_type = builder.getFunctionType(arg_types, llvm::None);
|
||||
auto function = mlir::FuncOp::create(location, proto.getName(), func_type);
|
||||
|
||||
// Mark the function as generic: it'll require type specialization for every
|
||||
// call site.
|
||||
if (function.getNumArguments())
|
||||
function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context));
|
||||
|
||||
function.setAttr("toy.generic", builder.getUnitAttr());
|
||||
return function;
|
||||
}
|
||||
|
||||
@@ -165,18 +150,22 @@ private:
|
||||
// argument list as the function itself.
|
||||
auto &entryBlock = *function.addEntryBlock();
|
||||
auto &protoArgs = funcAST.getProto()->getArgs();
|
||||
|
||||
// Declare all the function arguments in the symbol table.
|
||||
for (const auto &name_value :
|
||||
llvm::zip(protoArgs, entryBlock.getArguments())) {
|
||||
declare(std::get<0>(name_value)->getName(), std::get<1>(name_value));
|
||||
if (failed(declare(std::get<0>(name_value)->getName(),
|
||||
std::get<1>(name_value))))
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Create a builder for the function, it will be used throughout the codegen
|
||||
// to create operations in this function.
|
||||
builder = std::make_unique<mlir::OpBuilder>(function.getBody());
|
||||
// Set the insertion point in the builder to the beginning of the function
|
||||
// body, it will be used throughout the codegen to create operations in this
|
||||
// function.
|
||||
builder.setInsertionPointToStart(&entryBlock);
|
||||
|
||||
// Emit the body of the function.
|
||||
if (!mlirGen(*funcAST.getBody())) {
|
||||
if (mlir::failed(mlirGen(*funcAST.getBody()))) {
|
||||
function.erase();
|
||||
return nullptr;
|
||||
}
|
||||
@@ -184,10 +173,16 @@ private:
|
||||
// Implicitly return void if no return statement was emitted.
|
||||
// FIXME: we may fix the parser instead to always return the last expression
|
||||
// (this would possibly help the REPL case later)
|
||||
if (function.getBlocks().back().back().getName().getStringRef() !=
|
||||
"toy.return") {
|
||||
ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None);
|
||||
mlirGen(fakeRet);
|
||||
ReturnOp returnOp;
|
||||
if (!entryBlock.empty())
|
||||
returnOp = dyn_cast<ReturnOp>(entryBlock.back());
|
||||
if (!returnOp) {
|
||||
builder.create<ReturnOp>(loc(funcAST.getProto()->loc()));
|
||||
} else if (returnOp.hasOperand()) {
|
||||
// Otherwise, if this return operation has an operand then add a result to
|
||||
// the function.
|
||||
function.setType(builder.getFunctionType(function.getType().getInputs(),
|
||||
getType(VarType{})));
|
||||
}
|
||||
|
||||
return function;
|
||||
@@ -206,11 +201,11 @@ private:
|
||||
// and the result value is returned. If an error occurs we get a nullptr
|
||||
// and propagate.
|
||||
//
|
||||
mlir::Value *L = mlirGen(*binop.getLHS());
|
||||
if (!L)
|
||||
mlir::Value *lhs = mlirGen(*binop.getLHS());
|
||||
if (!lhs)
|
||||
return nullptr;
|
||||
mlir::Value *R = mlirGen(*binop.getRHS());
|
||||
if (!R)
|
||||
mlir::Value *rhs = mlirGen(*binop.getRHS());
|
||||
if (!rhs)
|
||||
return nullptr;
|
||||
auto location = loc(binop.loc());
|
||||
|
||||
@@ -218,123 +213,113 @@ private:
|
||||
// support '+' and '*'.
|
||||
switch (binop.getOp()) {
|
||||
case '+':
|
||||
return builder->create<AddOp>(location, L, R).getResult();
|
||||
break;
|
||||
return builder.create<AddOp>(location, lhs, rhs);
|
||||
case '*':
|
||||
return builder->create<MulOp>(location, L, R).getResult();
|
||||
default:
|
||||
emitError(loc(binop.loc()), "error: invalid binary operator '")
|
||||
<< binop.getOp() << "'";
|
||||
return nullptr;
|
||||
return builder.create<MulOp>(location, lhs, rhs);
|
||||
}
|
||||
|
||||
emitError(location, "invalid binary operator '") << binop.getOp() << "'";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// This is a reference to a variable in an expression. The variable is
|
||||
// expected to have been declared and so should have a value in the symbol
|
||||
// table, otherwise emit an error and return nullptr.
|
||||
/// This is a reference to a variable in an expression. The variable is
|
||||
/// expected to have been declared and so should have a value in the symbol
|
||||
/// table, otherwise emit an error and return nullptr.
|
||||
mlir::Value *mlirGen(VariableExprAST &expr) {
|
||||
if (symbolTable.count(expr.getName()))
|
||||
return symbolTable.lookup(expr.getName());
|
||||
if (auto *variable = symbolTable.lookup(expr.getName()))
|
||||
return variable;
|
||||
|
||||
emitError(loc(expr.loc()), "error: unknown variable '")
|
||||
<< expr.getName() << "'";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Emit a return operation, return true on success.
|
||||
bool mlirGen(ReturnExprAST &ret) {
|
||||
/// Emit a return operation. This will return failure if any generation fails.
|
||||
mlir::LogicalResult mlirGen(ReturnExprAST &ret) {
|
||||
auto location = loc(ret.loc());
|
||||
// `return` takes an optional expression, we need to account for it here.
|
||||
if (!ret.getExpr().hasValue()) {
|
||||
builder->create<ReturnOp>(location);
|
||||
return true;
|
||||
|
||||
// 'return' takes an optional expression, handle that case here.
|
||||
mlir::Value *expr = nullptr;
|
||||
if (ret.getExpr().hasValue()) {
|
||||
if (!(expr = mlirGen(*ret.getExpr().getValue())))
|
||||
return mlir::failure();
|
||||
}
|
||||
auto *expr = mlirGen(*ret.getExpr().getValue());
|
||||
if (!expr)
|
||||
return false;
|
||||
builder->create<ReturnOp>(location, expr);
|
||||
return true;
|
||||
|
||||
// Otherwise, this return operation has zero operands.
|
||||
builder.create<ReturnOp>(location, expr ? makeArrayRef(expr)
|
||||
: ArrayRef<mlir::Value *>());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Emit a literal/constant array. It will be emitted as a flattened array of
|
||||
// data in an Attribute attached to a `toy.constant` operation.
|
||||
// See documentation on [Attributes](LangRef.md#attributes) for more details.
|
||||
// Here is an excerpt:
|
||||
//
|
||||
// Attributes are the mechanism for specifying constant data in MLIR in
|
||||
// places where a variable is never allowed [...]. They consist of a name
|
||||
// and a [concrete attribute value](#attribute-values). It is possible to
|
||||
// attach attributes to operations, functions, and function arguments. The
|
||||
// set of expected attributes, their structure, and their interpretation
|
||||
// are all contextually dependent on what they are attached to.
|
||||
//
|
||||
// Example, the source level statement:
|
||||
// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
|
||||
// will be converted to:
|
||||
// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
|
||||
// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
|
||||
// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64>
|
||||
//
|
||||
/// Emit a literal/constant array. It will be emitted as a flattened array of
|
||||
/// data in an Attribute attached to a `toy.constant` operation.
|
||||
/// See documentation on [Attributes](LangRef.md#attributes) for more details.
|
||||
/// Here is an excerpt:
|
||||
///
|
||||
/// Attributes are the mechanism for specifying constant data in MLIR in
|
||||
/// places where a variable is never allowed [...]. They consist of a name
|
||||
/// and a concrete attribute value. The set of expected attributes, their
|
||||
/// structure, and their interpretation are all contextually dependent on
|
||||
/// what they are attached to.
|
||||
///
|
||||
/// Example, the source level statement:
|
||||
/// var a<2, 3> = [[1, 2, 3], [4, 5, 6]];
|
||||
/// will be converted to:
|
||||
/// %0 = "toy.constant"() {value: dense<tensor<2x3xf64>,
|
||||
/// [[1.000000e+00, 2.000000e+00, 3.000000e+00],
|
||||
/// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64>
|
||||
///
|
||||
mlir::Value *mlirGen(LiteralExprAST &lit) {
|
||||
auto location = loc(lit.loc());
|
||||
// The attribute is a vector with an attribute per element (number) in the
|
||||
// array, see `collectData()` below for more details.
|
||||
std::vector<mlir::Attribute> data;
|
||||
auto type = getType(lit.getDims());
|
||||
|
||||
// The attribute is a vector with a floating point value per element
|
||||
// (number) in the array, see `collectData()` below for more details.
|
||||
std::vector<double> data;
|
||||
data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1,
|
||||
std::multiplies<int>()));
|
||||
collectData(lit, data);
|
||||
|
||||
// FIXME: using a tensor type is a HACK here.
|
||||
// Can we do differently without registering a dialect? Using a string blob?
|
||||
mlir::Type elementType = mlir::FloatType::getF64(&context);
|
||||
auto dataType = builder->getTensorType(lit.getDims(), elementType);
|
||||
// The type of this attribute is tensor of 64-bit floating-point with the
|
||||
// shape of the literal.
|
||||
mlir::Type elementType = builder.getF64Type();
|
||||
auto dataType = builder.getTensorType(lit.getDims(), elementType);
|
||||
|
||||
// This is the actual attribute that actually hold the list of values for
|
||||
// this array literal.
|
||||
auto dataAttribute = builder->getDenseElementsAttr(dataType, data)
|
||||
.cast<mlir::DenseElementsAttr>();
|
||||
// This is the actual attribute that holds the list of values for this
|
||||
// tensor literal.
|
||||
auto dataAttribute =
|
||||
mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data));
|
||||
|
||||
// Build the MLIR op `toy.constant`, only boilerplate below.
|
||||
return builder->create<ConstantOp>(location, lit.getDims(), dataAttribute)
|
||||
.getResult();
|
||||
// Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build`
|
||||
// method.
|
||||
return builder.create<ConstantOp>(loc(lit.loc()), type, dataAttribute);
|
||||
}
|
||||
|
||||
// Recursive helper function to accumulate the data that compose an array
|
||||
// literal. It flattens the nested structure in the supplied vector. For
|
||||
// example with this array:
|
||||
// [[1, 2], [3, 4]]
|
||||
// we will generate:
|
||||
// [ 1, 2, 3, 4 ]
|
||||
// Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`.
|
||||
// Attributes are the way MLIR attaches constant to operations and functions.
|
||||
void collectData(ExprAST &expr, std::vector<mlir::Attribute> &data) {
|
||||
/// Recursive helper function to accumulate the data that compose an array
|
||||
/// literal. It flattens the nested structure in the supplied vector. For
|
||||
/// example with this array:
|
||||
/// [[1, 2], [3, 4]]
|
||||
/// we will generate:
|
||||
/// [ 1, 2, 3, 4 ]
|
||||
/// Individual numbers are represented as doubles.
|
||||
/// Attributes are the way MLIR attaches constant to operations.
|
||||
void collectData(ExprAST &expr, std::vector<double> &data) {
|
||||
if (auto *lit = dyn_cast<LiteralExprAST>(&expr)) {
|
||||
for (auto &value : lit->getValues())
|
||||
collectData(*value, data);
|
||||
return;
|
||||
}
|
||||
|
||||
assert(isa<NumberExprAST>(expr) && "expected literal or number expr");
|
||||
mlir::Type elementType = mlir::FloatType::getF64(&context);
|
||||
auto attr = mlir::FloatAttr::getChecked(
|
||||
elementType, cast<NumberExprAST>(expr).getValue(), loc(expr.loc()));
|
||||
data.push_back(attr);
|
||||
data.push_back(cast<NumberExprAST>(expr).getValue());
|
||||
}
|
||||
|
||||
// Emit a call expression. It emits specific operations for the `transpose`
|
||||
// builtin. Other identifiers are assumed to be user-defined functions.
|
||||
/// Emit a call expression. It emits specific operations for the `transpose`
|
||||
/// builtin. Other identifiers are assumed to be user-defined functions.
|
||||
mlir::Value *mlirGen(CallExprAST &call) {
|
||||
llvm::StringRef callee = call.getCallee();
|
||||
auto location = loc(call.loc());
|
||||
std::string callee = call.getCallee();
|
||||
if (callee == "transpose") {
|
||||
if (call.getArgs().size() != 1) {
|
||||
emitError(location, "MLIR codegen encountered an error: toy.transpose "
|
||||
"does not accept multiple arguments");
|
||||
return nullptr;
|
||||
}
|
||||
mlir::Value *arg = mlirGen(*call.getArgs()[0]);
|
||||
return builder->create<TransposeOp>(location, arg).getResult();
|
||||
}
|
||||
|
||||
// Codegen the operands first
|
||||
// Codegen the operands first.
|
||||
SmallVector<mlir::Value *, 4> operands;
|
||||
for (auto &expr : call.getArgs()) {
|
||||
auto *arg = mlirGen(*expr);
|
||||
@@ -342,34 +327,41 @@ private:
|
||||
return nullptr;
|
||||
operands.push_back(arg);
|
||||
}
|
||||
// Calls to user-defined function are mapped to a custom call that takes
|
||||
// the callee name as an attribute.
|
||||
return builder->create<GenericCallOp>(location, call.getCallee(), operands)
|
||||
.getResult();
|
||||
|
||||
// Builting calls have their custom operation, meaning this is a
|
||||
// straightforward emission.
|
||||
if (callee == "transpose") {
|
||||
if (call.getArgs().size() != 1) {
|
||||
emitError(location, "MLIR codegen encountered an error: toy.transpose "
|
||||
"does not accept multiple arguments");
|
||||
return nullptr;
|
||||
}
|
||||
return builder.create<TransposeOp>(location, operands[0]);
|
||||
}
|
||||
|
||||
// Otherwise this is a call to a user-defined function. Calls to ser-defined
|
||||
// functions are mapped to a custom call that takes the callee name as an
|
||||
// attribute.
|
||||
return builder.create<GenericCallOp>(location, callee, operands);
|
||||
}
|
||||
|
||||
// Emit a call expression. It emits specific operations for two builtins:
|
||||
// transpose(x) and print(x). Other identifiers are assumed to be user-defined
|
||||
// functions. Return false on failure.
|
||||
bool mlirGen(PrintExprAST &call) {
|
||||
/// Emit a print expression. It emits specific operations for two builtins:
|
||||
/// transpose(x) and print(x).
|
||||
mlir::LogicalResult mlirGen(PrintExprAST &call) {
|
||||
auto *arg = mlirGen(*call.getArg());
|
||||
if (!arg)
|
||||
return false;
|
||||
auto location = loc(call.loc());
|
||||
builder->create<PrintOp>(location, arg);
|
||||
return true;
|
||||
return mlir::failure();
|
||||
|
||||
builder.create<PrintOp>(loc(call.loc()), arg);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Emit a constant for a single number (FIXME: semantic? broadcast?)
|
||||
/// Emit a constant for a single number (FIXME: semantic? broadcast?)
|
||||
mlir::Value *mlirGen(NumberExprAST &num) {
|
||||
auto location = loc(num.loc());
|
||||
mlir::Type elementType = mlir::FloatType::getF64(&context);
|
||||
auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(),
|
||||
loc(num.loc()));
|
||||
return builder->create<ConstantOp>(location, attr).getResult();
|
||||
return builder.create<ConstantOp>(loc(num.loc()), num.getValue());
|
||||
}
|
||||
|
||||
// Dispatch codegen for the right expression subclass using RTTI.
|
||||
/// Dispatch codegen for the right expression subclass using RTTI.
|
||||
mlir::Value *mlirGen(ExprAST &expr) {
|
||||
switch (expr.getKind()) {
|
||||
case toy::ExprAST::Expr_BinOp:
|
||||
@@ -390,77 +382,75 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
// Handle a variable declaration, we'll codegen the expression that forms the
|
||||
// initializer and record the value in the symbol table before returning it.
|
||||
// Future expressions will be able to reference this variable through symbol
|
||||
// table lookup.
|
||||
/// Handle a variable declaration, we'll codegen the expression that forms the
|
||||
/// initializer and record the value in the symbol table before returning it.
|
||||
/// Future expressions will be able to reference this variable through symbol
|
||||
/// table lookup.
|
||||
mlir::Value *mlirGen(VarDeclExprAST &vardecl) {
|
||||
mlir::Value *value = nullptr;
|
||||
auto location = loc(vardecl.loc());
|
||||
if (auto init = vardecl.getInitVal()) {
|
||||
value = mlirGen(*init);
|
||||
if (!value)
|
||||
return nullptr;
|
||||
// We have the initializer value, but in case the variable was declared
|
||||
// with specific shape, we emit a "reshape" operation. It will get
|
||||
// optimized out later as needed.
|
||||
if (!vardecl.getType().shape.empty()) {
|
||||
value = builder
|
||||
->create<ReshapeOp>(
|
||||
location, value,
|
||||
getType(vardecl.getType()).cast<ToyArrayType>())
|
||||
.getResult();
|
||||
}
|
||||
} else {
|
||||
auto init = vardecl.getInitVal();
|
||||
if (!init) {
|
||||
emitError(loc(vardecl.loc()),
|
||||
"missing initializer in variable declaration");
|
||||
return nullptr;
|
||||
}
|
||||
// Register the value in the symbol table
|
||||
declare(vardecl.getName(), value);
|
||||
|
||||
mlir::Value *value = mlirGen(*init);
|
||||
if (!value)
|
||||
return nullptr;
|
||||
|
||||
// We have the initializer value, but in case the variable was declared
|
||||
// with specific shape, we emit a "reshape" operation. It will get
|
||||
// optimized out later as needed.
|
||||
if (!vardecl.getType().shape.empty()) {
|
||||
value = builder.create<ReshapeOp>(loc(vardecl.loc()),
|
||||
getType(vardecl.getType()), value);
|
||||
}
|
||||
|
||||
// Register the value in the symbol table.
|
||||
if (failed(declare(vardecl.getName(), value)))
|
||||
return nullptr;
|
||||
return value;
|
||||
}
|
||||
|
||||
/// Codegen a list of expression, return false if one of them hit an error.
|
||||
bool mlirGen(ExprASTList &blockAST) {
|
||||
ScopedHashTableScope<llvm::StringRef, mlir::Value *> var_scope(symbolTable);
|
||||
/// Codegen a list of expression, return failure if one of them hit an error.
|
||||
mlir::LogicalResult mlirGen(ExprASTList &blockAST) {
|
||||
ScopedHashTableScope<StringRef, mlir::Value *> var_scope(symbolTable);
|
||||
for (auto &expr : blockAST) {
|
||||
// Specific handling for variable declarations, return statement, and
|
||||
// print. These can only appear in block list and not in nested
|
||||
// expressions.
|
||||
if (auto *vardecl = dyn_cast<VarDeclExprAST>(expr.get())) {
|
||||
if (!mlirGen(*vardecl))
|
||||
return false;
|
||||
return mlir::failure();
|
||||
continue;
|
||||
}
|
||||
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get())) {
|
||||
if (!mlirGen(*ret))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
if (auto *ret = dyn_cast<ReturnExprAST>(expr.get()))
|
||||
return mlirGen(*ret);
|
||||
if (auto *print = dyn_cast<PrintExprAST>(expr.get())) {
|
||||
if (!mlirGen(*print))
|
||||
return false;
|
||||
if (mlir::failed(mlirGen(*print)))
|
||||
return mlir::success();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Generic expression dispatch codegen.
|
||||
if (!mlirGen(*expr))
|
||||
return false;
|
||||
return mlir::failure();
|
||||
}
|
||||
return true;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Build a type from a list of shape dimensions. Types are `array` followed
|
||||
/// by an optional dimension list, example: array<2, 2>
|
||||
/// They are wrapped in a `toy` dialect (see next chapter) and get printed:
|
||||
/// !toy.array<2, 2>
|
||||
template <typename T> mlir::Type getType(T shape) {
|
||||
SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
|
||||
return ToyArrayType::get(&context, shape64);
|
||||
/// Build a tensor type from a list of shape dimensions.
|
||||
mlir::Type getType(ArrayRef<int64_t> shape) {
|
||||
// If the shape is empty, then this type is unranked.
|
||||
if (shape.empty())
|
||||
return builder.getTensorType(builder.getF64Type());
|
||||
|
||||
// Otherwise, we use the given shape.
|
||||
return builder.getTensorType(shape, builder.getF64Type());
|
||||
}
|
||||
|
||||
/// Build an MLIR type from a Toy AST variable type
|
||||
/// (forward to the generic getType(T) above).
|
||||
/// Build an MLIR type from a Toy AST variable type (forward to the generic
|
||||
/// getType above).
|
||||
mlir::Type getType(const VarType &type) { return getType(type.shape); }
|
||||
};
|
||||
|
||||
|
||||
75
mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
Normal file
75
mlir/examples/toy/Ch3/mlir/ToyCombine.cpp
Normal file
@@ -0,0 +1,75 @@
|
||||
//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements a simple combiner for optimizing pattern in the Toy
|
||||
// dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "toy/Dialect.h"
|
||||
#include <numeric>
|
||||
using namespace mlir;
|
||||
using namespace toy;
|
||||
|
||||
namespace {
|
||||
/// Include the patterns defined in the Declarative Rewrite framework.
|
||||
#include "ToyCombine.inc"
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Fold transpose(transpose(x) -> transpose(x)
|
||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||
/// We register this pattern to match every toy.transpose in the IR.
|
||||
/// The "benefit" is used by the framework to order the patterns and process
|
||||
/// them in order of profitability.
|
||||
SimplifyRedundantTranspose(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
|
||||
|
||||
/// This method attempts to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. It is expected
|
||||
/// to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current transpose.
|
||||
mlir::Value *transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
|
||||
// Use the rewriter to perform the replacement
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
/// Register our patterns for rewrite by the Canonicalization framework.
|
||||
void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<SimplifyRedundantTranspose>(context);
|
||||
}
|
||||
|
||||
/// Register our patterns for rewrite by the Canonicalization framework.
|
||||
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
|
||||
FoldConstantReshapeOptPattern>(context);
|
||||
}
|
||||
72
mlir/examples/toy/Ch3/mlir/ToyCombine.td
Normal file
72
mlir/examples/toy/Ch3/mlir/ToyCombine.td
Normal file
@@ -0,0 +1,72 @@
|
||||
//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Defines language-specific pattern match optimizations for Toy using
|
||||
// Declarative Rewrite Rules (DRR) specified using TableGen records.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TOY_COMBINE
|
||||
#define TOY_COMBINE
|
||||
|
||||
#ifndef OP_BASE
|
||||
include "toy/Ops.td"
|
||||
#endif // OP_BASE
|
||||
|
||||
/* Pattern-Match and Rewrite using DRR:
|
||||
class Pattern<
|
||||
dag sourcePattern, list<dag> resultPatterns,
|
||||
list<dag> additionalConstraints = [],
|
||||
dag benefitsAdded = (addBenefit 0)>;
|
||||
*/
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Basic Pattern-Match and Rewrite
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Reshape(Reshape(x)) = x
|
||||
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
|
||||
(ReshapeOp $arg)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern-Match and Rewrite using Native Code Call
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Native Code Calls may be used for more complex transformations using inline
|
||||
// C++ and C++ helper functions.
|
||||
|
||||
// Reshape(Constant(x)) = x'
|
||||
def ReshapeConstant :
|
||||
NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
|
||||
def FoldConstantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res (ConstantOp $arg)),
|
||||
(ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pattern-Match and Rewrite with Constraints
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// DRR allows for constraint checking when the transformation is conditional
|
||||
// on operand properties.
|
||||
|
||||
// Reshape(x) = x, where input and output shapes are identical
|
||||
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
|
||||
def RedundantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreIdentical $res, $arg)]>;
|
||||
|
||||
#endif // TOY_COMBINE
|
||||
@@ -1,390 +0,0 @@
|
||||
//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file implements the dialect for the Toy IR: custom type parsing and
|
||||
// operation verification.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "toy/Dialect.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/iterator_range.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::raw_string_ostream;
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
|
||||
namespace toy {
|
||||
namespace detail {
|
||||
|
||||
/// This class holds the implementation of the ToyArrayType.
|
||||
/// It is intended to be uniqued based on its content and owned by the context.
|
||||
struct ToyArrayTypeStorage : public mlir::TypeStorage {
|
||||
/// This defines how we unique this type in the context: our key contains
|
||||
/// only the shape, a more complex type would have multiple entries in the
|
||||
/// tuple here.
|
||||
/// The element of the tuples usually matches 1-1 the arguments from the
|
||||
/// public `get()` method arguments from the facade.
|
||||
using KeyTy = std::tuple<ArrayRef<int64_t>>;
|
||||
static unsigned hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(std::get<0>(key));
|
||||
}
|
||||
/// When the key hash hits an existing type, we compare the shape themselves
|
||||
/// to confirm we have the right type.
|
||||
bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); }
|
||||
|
||||
/// This is a factory method to create our type storage. It is only
|
||||
/// invoked after looking up the type in the context using the key and not
|
||||
/// finding it.
|
||||
static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
|
||||
const KeyTy &key) {
|
||||
// Copy the shape array into the bumpptr allocator owned by the context.
|
||||
ArrayRef<int64_t> shape = allocator.copyInto(std::get<0>(key));
|
||||
|
||||
// Allocate the instance for the ToyArrayTypeStorage itself
|
||||
auto *storage = allocator.allocate<ToyArrayTypeStorage>();
|
||||
// Initialize the instance using placement new.
|
||||
return new (storage) ToyArrayTypeStorage(shape);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> getShape() const { return shape; }
|
||||
|
||||
private:
|
||||
ArrayRef<int64_t> shape;
|
||||
|
||||
/// Constructor is only invoked from the `construct()` method above.
|
||||
ToyArrayTypeStorage(ArrayRef<int64_t> shape) : shape(shape) {}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
mlir::Type ToyArrayType::getElementType() {
|
||||
return mlir::FloatType::getF64(getContext());
|
||||
}
|
||||
|
||||
ToyArrayType ToyArrayType::get(mlir::MLIRContext *context,
|
||||
ArrayRef<int64_t> shape) {
|
||||
return Base::get(context, ToyTypeKind::TOY_ARRAY, shape);
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> ToyArrayType::getShape() { return getImpl()->getShape(); }
|
||||
|
||||
/// Dialect creation, the instance will be owned by the context. This is the
|
||||
/// point of registration of custom types and operations for the dialect.
|
||||
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
|
||||
addOperations<ConstantOp, GenericCallOp, PrintOp, TransposeOp, ReshapeOp,
|
||||
MulOp, AddOp, ReturnOp>();
|
||||
addTypes<ToyArrayType>();
|
||||
}
|
||||
|
||||
/// Parse a type registered to this dialect, we expect only Toy arrays.
|
||||
mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const {
|
||||
// Sanity check: we only support array or array<...>
|
||||
if (!tyData.startswith("array")) {
|
||||
emitError(loc, "invalid Toy type '" + tyData + "', array expected");
|
||||
return nullptr;
|
||||
}
|
||||
// Drop the "array" prefix from the type name, we expect either an empty
|
||||
// string or just the shape.
|
||||
tyData = tyData.drop_front(StringRef("array").size());
|
||||
// This is the generic array case without shape, early return it.
|
||||
if (tyData.empty())
|
||||
return ToyArrayType::get(getContext());
|
||||
|
||||
// Use a regex to parse the shape (for efficient we should store this regex in
|
||||
// the dialect itself).
|
||||
SmallVector<StringRef, 4> matches;
|
||||
auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
|
||||
if (!shapeRegex.match(tyData, &matches)) {
|
||||
emitError(loc, "invalid toy array shape '" + tyData + "'");
|
||||
return nullptr;
|
||||
}
|
||||
SmallVector<int64_t, 4> shape;
|
||||
// Iterate through the captures, skip the first one which is the full string.
|
||||
for (auto dimStr :
|
||||
llvm::make_range(std::next(matches.begin()), matches.end())) {
|
||||
if (dimStr.startswith(","))
|
||||
continue; // POSIX misses non-capturing groups.
|
||||
if (dimStr.empty())
|
||||
continue; // '*' makes it an optional group capture
|
||||
// Convert the capture to an integer
|
||||
unsigned long long dim;
|
||||
if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
|
||||
emitError(loc, "couldn't parse dimension as integer, matched: " + dimStr);
|
||||
return mlir::Type();
|
||||
}
|
||||
shape.push_back(dim);
|
||||
}
|
||||
// Finally we collected all the dimensions in the shape,
|
||||
// create the array type.
|
||||
return ToyArrayType::get(getContext(), shape);
|
||||
}
|
||||
|
||||
/// Print a Toy array type, for example `array<2, 3, 4>`
|
||||
void ToyDialect::printType(mlir::Type type, raw_ostream &os) const {
|
||||
auto arrayTy = type.dyn_cast<ToyArrayType>();
|
||||
if (!arrayTy) {
|
||||
os << "unknown toy type";
|
||||
return;
|
||||
}
|
||||
os << "array";
|
||||
if (!arrayTy.getShape().empty()) {
|
||||
os << "<";
|
||||
mlir::interleaveComma(arrayTy.getShape(), os);
|
||||
os << ">";
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////// Custom Operations for the Dialect /////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Helper to verify that the result of an operation is a Toy array type.
|
||||
template <typename T> static mlir::LogicalResult verifyToyReturnArray(T *op) {
|
||||
if (!op->getResult()->getType().template isa<ToyArrayType>()) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
os << "expects a Toy Array for its argument, got "
|
||||
<< op->getResult()->getType();
|
||||
return op->emitOpError(os.str());
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Helper to verify that the two operands of a binary operation are Toy
|
||||
/// arrays..
|
||||
template <typename T> static mlir::LogicalResult verifyToyBinOperands(T *op) {
|
||||
if (!op->getOperand(0)->getType().template isa<ToyArrayType>()) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
os << "expects a Toy Array for its LHS, got "
|
||||
<< op->getOperand(0)->getType();
|
||||
return op->emitOpError(os.str());
|
||||
}
|
||||
if (!op->getOperand(1)->getType().template isa<ToyArrayType>()) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
os << "expects a Toy Array for its LHS, got "
|
||||
<< op->getOperand(0)->getType();
|
||||
return op->emitOpError(os.str());
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Build a constant operation.
|
||||
/// The builder is passed as an argument, so is the state that this method is
|
||||
/// expected to fill in order to build the operation.
|
||||
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
ArrayRef<int64_t> shape, mlir::DenseElementsAttr value) {
|
||||
state.types.push_back(ToyArrayType::get(builder->getContext(), shape));
|
||||
auto dataAttribute = builder->getNamedAttr("value", value);
|
||||
state.attributes.push_back(dataAttribute);
|
||||
}
|
||||
|
||||
/// Build a constant operation.
|
||||
/// The builder is passed as an argument, so is the state that this method is
|
||||
/// expected to fill in order to build the operation.
|
||||
void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::FloatAttr value) {
|
||||
// Broadcast and forward to the other build factory
|
||||
mlir::Type elementType = mlir::FloatType::getF64(builder->getContext());
|
||||
auto dataType = builder->getTensorType({1}, elementType);
|
||||
auto dataAttribute = builder->getDenseElementsAttr(dataType, {value})
|
||||
.cast<mlir::DenseElementsAttr>();
|
||||
|
||||
ConstantOp::build(builder, state, {1}, dataAttribute);
|
||||
}
|
||||
|
||||
/// Verifier for constant operation.
|
||||
mlir::LogicalResult ConstantOp::verify() {
|
||||
// Ensure that the return type is a Toy array
|
||||
if (failed(verifyToyReturnArray(this)))
|
||||
return mlir::failure();
|
||||
|
||||
// We expect the constant itself to be stored as an attribute.
|
||||
auto dataAttr = getAttr("value").dyn_cast<mlir::DenseElementsAttr>();
|
||||
if (!dataAttr) {
|
||||
return emitOpError(
|
||||
"missing valid `value` DenseElementsAttribute on toy.constant()");
|
||||
}
|
||||
auto attrType = dataAttr.getType().dyn_cast<mlir::TensorType>();
|
||||
if (!attrType) {
|
||||
return emitOpError(
|
||||
"missing valid `value` DenseElementsAttribute on toy.constant()");
|
||||
}
|
||||
|
||||
// If the return type of the constant is not a generic array, the shape must
|
||||
// match the shape of the attribute holding the data.
|
||||
auto resultType = getResult()->getType().cast<ToyArrayType>();
|
||||
if (!resultType.isGeneric()) {
|
||||
if (attrType.getRank() != resultType.getRank()) {
|
||||
return emitOpError("The rank of the toy.constant return type must match "
|
||||
"the one of the attached value attribute: " +
|
||||
Twine(attrType.getRank()) +
|
||||
" != " + Twine(resultType.getRank()));
|
||||
}
|
||||
for (int dim = 0; dim < attrType.getRank(); ++dim) {
|
||||
if (attrType.getShape()[dim] != resultType.getShape()[dim]) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
return emitOpError(
|
||||
"Shape mismatch between toy.constant return type and its "
|
||||
"attribute at dimension " +
|
||||
Twine(dim) + ": " + Twine(attrType.getShape()[dim]) +
|
||||
" != " + Twine(resultType.getShape()[dim]));
|
||||
}
|
||||
}
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
StringRef callee, ArrayRef<mlir::Value *> arguments) {
|
||||
// Generic call always returns a generic ToyArray initially
|
||||
state.types.push_back(ToyArrayType::get(builder->getContext()));
|
||||
state.operands.assign(arguments.begin(), arguments.end());
|
||||
auto calleeAttr = builder->getStringAttr(callee);
|
||||
state.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
|
||||
}
|
||||
|
||||
mlir::LogicalResult GenericCallOp::verify() {
|
||||
// Verify that every operand is a Toy Array
|
||||
for (int opId = 0, num = getNumOperands(); opId < num; ++opId) {
|
||||
if (!getOperand(opId)->getType().template isa<ToyArrayType>()) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
os << "expects a Toy Array for its " << opId << " operand, got "
|
||||
<< getOperand(opId)->getType();
|
||||
return emitOpError(os.str());
|
||||
}
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
/// Return the name of the callee.
|
||||
StringRef GenericCallOp::getCalleeName() {
|
||||
return getAttr("callee").cast<mlir::StringAttr>().getValue();
|
||||
}
|
||||
|
||||
template <typename T> static mlir::LogicalResult verifyToySingleOperand(T *op) {
|
||||
if (!op->getOperand()->getType().template isa<ToyArrayType>()) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
os << "expects a Toy Array for its argument, got "
|
||||
<< op->getOperand()->getType();
|
||||
return op->emitOpError(os.str());
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value) {
|
||||
// Return does not return any value and has an optional single argument
|
||||
if (value)
|
||||
state.operands.push_back(value);
|
||||
}
|
||||
|
||||
mlir::LogicalResult ReturnOp::verify() {
|
||||
if (getNumOperands() > 1) {
|
||||
std::string msg;
|
||||
raw_string_ostream os(msg);
|
||||
os << "expects zero or one operand, got " << getNumOperands();
|
||||
return emitOpError(os.str());
|
||||
}
|
||||
if (hasOperand() && failed(verifyToySingleOperand(this)))
|
||||
return mlir::failure();
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value) {
|
||||
// Print does not return any value and has a single argument
|
||||
state.operands.push_back(value);
|
||||
}
|
||||
|
||||
mlir::LogicalResult PrintOp::verify() {
|
||||
if (failed(verifyToySingleOperand(this)))
|
||||
return mlir::failure();
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value) {
|
||||
state.types.push_back(ToyArrayType::get(builder->getContext()));
|
||||
state.operands.push_back(value);
|
||||
}
|
||||
|
||||
mlir::LogicalResult TransposeOp::verify() {
|
||||
if (failed(verifyToySingleOperand(this)))
|
||||
return mlir::failure();
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *value, ToyArrayType reshapedType) {
|
||||
state.types.push_back(reshapedType);
|
||||
state.operands.push_back(value);
|
||||
}
|
||||
|
||||
mlir::LogicalResult ReshapeOp::verify() {
|
||||
if (failed(verifyToySingleOperand(this)))
|
||||
return mlir::failure();
|
||||
auto retTy = getResult()->getType().dyn_cast<ToyArrayType>();
|
||||
if (!retTy)
|
||||
return emitOpError("toy.reshape is expected to produce a Toy array");
|
||||
if (retTy.isGeneric())
|
||||
return emitOpError("toy.reshape is expected to produce a shaped Toy array, "
|
||||
"got a generic one.");
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void AddOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *lhs, mlir::Value *rhs) {
|
||||
state.types.push_back(ToyArrayType::get(builder->getContext()));
|
||||
state.operands.push_back(lhs);
|
||||
state.operands.push_back(rhs);
|
||||
}
|
||||
|
||||
mlir::LogicalResult AddOp::verify() {
|
||||
if (failed(verifyToyBinOperands(this)))
|
||||
return mlir::failure();
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void MulOp::build(mlir::Builder *builder, mlir::OperationState &state,
|
||||
mlir::Value *lhs, mlir::Value *rhs) {
|
||||
state.types.push_back(ToyArrayType::get(builder->getContext()));
|
||||
state.operands.push_back(lhs);
|
||||
state.operands.push_back(rhs);
|
||||
}
|
||||
|
||||
mlir::LogicalResult MulOp::verify() {
|
||||
if (failed(verifyToyBinOperands(this)))
|
||||
return mlir::failure();
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace toy
|
||||
@@ -125,7 +125,7 @@ void ASTDumper::dump(NumberExprAST *num) {
|
||||
llvm::errs() << num->getValue() << " " << loc(num) << "\n";
|
||||
}
|
||||
|
||||
/// Helper to print recursively a literal. This handles nested array like:
|
||||
/// Helper to print recurisvely a literal. This handles nested array like:
|
||||
/// [ [ 1, 2 ], [ 3, 4 ] ]
|
||||
/// We print out such array with the dimensions spelled out at every level:
|
||||
/// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ]
|
||||
|
||||
@@ -22,12 +22,14 @@
|
||||
#include "toy/Dialect.h"
|
||||
#include "toy/MLIRGen.h"
|
||||
#include "toy/Parser.h"
|
||||
#include <memory>
|
||||
|
||||
#include "mlir/Analysis/Verifier.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
@@ -61,6 +63,8 @@ static cl::opt<enum Action> emitAction(
|
||||
cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
||||
cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")));
|
||||
|
||||
static cl::opt<bool> EnableOpt("opt", cl::desc("Enable optimizations"));
|
||||
|
||||
/// Returns a Toy AST resulting from parsing the file or a nullptr on error.
|
||||
std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> FileOrErr =
|
||||
@@ -75,9 +79,18 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
||||
return parser.ParseModule();
|
||||
}
|
||||
|
||||
mlir::LogicalResult optimize(mlir::ModuleOp module) {
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
|
||||
// Apply any generic pass manager command line options and run the pipeline.
|
||||
applyPassManagerCLOptions(pm);
|
||||
return pm.run(module);
|
||||
}
|
||||
|
||||
int dumpMLIR() {
|
||||
// Register our Dialect with MLIR
|
||||
mlir::registerDialect<ToyDialect>();
|
||||
mlir::registerDialect<mlir::toy::ToyDialect>();
|
||||
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module;
|
||||
@@ -106,6 +119,12 @@ int dumpMLIR() {
|
||||
}
|
||||
if (!module)
|
||||
return 1;
|
||||
if (EnableOpt) {
|
||||
if (failed(optimize(*module))) {
|
||||
llvm::errs() << "Module optimization failed\n";
|
||||
return 7;
|
||||
}
|
||||
}
|
||||
module->dump();
|
||||
return 0;
|
||||
}
|
||||
@@ -125,6 +144,7 @@ int dumpAST() {
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
mlir::registerPassManagerCLOptions();
|
||||
cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
|
||||
|
||||
switch (emitAction) {
|
||||
|
||||
@@ -1,296 +1,262 @@
|
||||
# Chapter 3: Defining and Registering a Dialect in MLIR
|
||||
# Chapter 3: High-level Language-Specific Analysis and Transformation
|
||||
|
||||
In the previous chapter, we saw how to emit a custom IR for Toy in MLIR using
|
||||
opaque operations. In this chapter we will register our Dialect with MLIR to
|
||||
start making the Toy IR more robust and friendly to use.
|
||||
Creating a dialect that closely represents the semantics of an input language
|
||||
enables analyses, transformations and optimizations in MLIR that require high
|
||||
level language information and are generally performed on the language AST. For
|
||||
example, `clang` has a fairly
|
||||
[heavy mechanism](https://clang.llvm.org/doxygen/classclang_1_1TreeTransform.html)
|
||||
for performing template instantiation in C++.
|
||||
|
||||
Dialects in MLIR allow for registering operations and types with an MLIRContext.
|
||||
They also must reserve a "namespace" to avoid collision with other registered
|
||||
dialects. These registered operations are no longer opaque to MLIR: for example
|
||||
we can teach the MLIR verifier to enforce some invariants on the IR.
|
||||
We divide compiler transformations into two categories: local and global. In
|
||||
this chapter, we focus on how to leverage the Toy Dialect and its high-level
|
||||
semantics to perform local pattern-match transformations that would be difficult
|
||||
in LLVM. For this, we use MLIR's
|
||||
[Generic DAG Rewriter](../../GenericDAGRewriter.md).
|
||||
|
||||
```c++
|
||||
/// This is the definition of the Toy dialect. A dialect inherits from
|
||||
/// mlir::Dialect and registers custom operations and types (in its constructor).
|
||||
/// It can also override general behavior of dialects exposed as virtual
|
||||
/// methods, for example regarding verification and parsing/printing.
|
||||
class ToyDialect : public mlir::Dialect {
|
||||
public:
|
||||
explicit ToyDialect(mlir::MLIRContext *ctx);
|
||||
There are two methods that can be used to implement pattern-match
|
||||
transformations: 1. Imperative, C++ pattern-match and rewrite 2. Declarative,
|
||||
rule-based pattern-match and rewrite using
|
||||
[Table-driven Declarative Rewrite Rule](../../DeclarativeRewrites.md) (DRR).
|
||||
Note that the use of DRR requires that the operations be defined using ODS as
|
||||
described in [Chapter 2](../Ch-2.md).
|
||||
|
||||
/// Parse a type registered to this dialect. Overriding this method is
|
||||
/// required for dialects that have custom types.
|
||||
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||
mlir::Type parseType(llvm::StringRef tyData,
|
||||
mlir::Location loc) const override;
|
||||
# Optimize Transpose using C++ style pattern-match and rewrite
|
||||
|
||||
/// Print a type registered to this dialect. Overriding this method is
|
||||
/// only required for dialects that have custom types.
|
||||
/// Technically this is only needed to be able to round-trip to textual IR.
|
||||
void printType(mlir::Type type, llvm::raw_ostream &os) const override;
|
||||
};
|
||||
```
|
||||
Let's start with a simple pattern and try to eliminate a sequence of two
|
||||
transpose that cancel out: `transpose(transpose(X)) -> X`. Here is the
|
||||
corresponding Toy example:
|
||||
|
||||
The dialect can now be registered in the global registry:
|
||||
|
||||
```c++
|
||||
mlir::registerDialect<ToyDialect>();
|
||||
```
|
||||
|
||||
Any new `MLIRContext` created from now on will recognize the `toy` prefix when
|
||||
parsing new types and invoke our `parseType` method. We will see later how to
|
||||
enable custom operations, but first let's define a custom type to handle Toy
|
||||
arrays.
|
||||
|
||||
# Custom Type Handling
|
||||
|
||||
As you may have noticed in the previous chapter, dialect specific types in MLIR
|
||||
are serialized as strings. In the case of Toy, an example would be
|
||||
`!toy.array<2, 3>`. MLIR will find the ToyDialect from the `!toy` prefix but it
|
||||
is up to the dialect itself to translate the content of the string into a proper
|
||||
type.
|
||||
|
||||
First we need to define the class representing our type. In MLIR, types are
|
||||
references to immutable and uniqued objects owned by the MLIRContext. As such,
|
||||
our `ToyArrayType` will only be a wrapper around a pointer to an uniqued
|
||||
instance of `ToyArrayTypeStorage` in the Context and provide the public facade
|
||||
API to interact with the type.
|
||||
|
||||
```c++
|
||||
class ToyArrayType : public mlir::Type::TypeBase<ToyArrayType, mlir::Type,
|
||||
detail::ToyArrayTypeStorage> {
|
||||
public:
|
||||
/// Returns the dimensions for this Toy array, or an empty range for a generic array.
|
||||
llvm::ArrayRef<int64_t> getShape();
|
||||
|
||||
/// Predicate to test if this array is generic (shape haven't been inferred yet).
|
||||
bool isGeneric() { return getShape().empty(); }
|
||||
|
||||
/// Return the rank of this array (0 if it is generic)
|
||||
int getRank() { return getShape().size(); }
|
||||
|
||||
/// Get the unique instance of this Type from the context.
|
||||
/// A ToyArrayType is only defined by the shape of the array.
|
||||
static ToyArrayType get(mlir::MLIRContext *context,
|
||||
llvm::ArrayRef<int64_t> shape = {});
|
||||
|
||||
/// Support method to enable LLVM-style RTTI type casting.
|
||||
static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; }
|
||||
};
|
||||
```
|
||||
|
||||
Implementing `getShape()` for example is just about retrieving the pointer to
|
||||
the uniqued instance and forwarding:
|
||||
|
||||
```c++
|
||||
llvm::ArrayRef<int64_t> ToyArrayType::getShape() {
|
||||
return getImpl()->getShape();
|
||||
```Toy(.toy)
|
||||
def transpose_transpose(x) {
|
||||
return transpose(transpose(x));
|
||||
}
|
||||
```
|
||||
|
||||
The calls to `getImpl()` give access to the `ToyArrayTypeStorage` that holds the
|
||||
information for this type. For details about how the storage of the type works,
|
||||
we'll refer you to `Ch3/mlir/ToyDialect.cpp`.
|
||||
|
||||
Finally, the Toy dialect can register the type with MLIR, and implement some
|
||||
custom parsing for our types:
|
||||
|
||||
```c++
|
||||
ToyDialect::ToyDialect(mlir::MLIRContext *ctx)
|
||||
// note the `toy` prefix that we reserve here.
|
||||
: mlir::Dialect("toy", ctx) {
|
||||
// Register our custom type with MLIR.
|
||||
addTypes<ToyArrayType>();
|
||||
}
|
||||
|
||||
/// Parse a type registered to this dialect, we expect only Toy arrays.
|
||||
mlir::Type ToyDialect::parseType(StringRef tyData,
|
||||
mlir::Location loc) const {
|
||||
// Sanity check: we only support array or array<...>
|
||||
if (!tyData.startswith("array")) {
|
||||
getContext()->emitError(loc, "Invalid Toy type '" + tyData +
|
||||
"', array expected");
|
||||
return nullptr;
|
||||
}
|
||||
// Drop the "array" prefix from the type name, we expect either an empty
|
||||
// string or just the shape.
|
||||
tyData = tyData.drop_front(StringRef("array").size());
|
||||
// This is the generic array case without shape, early return it.
|
||||
if (tyData.empty())
|
||||
return ToyArrayType::get(getContext());
|
||||
|
||||
// Use a regex to parse the shape (for efficient we should store this regex in
|
||||
// the dialect itself).
|
||||
SmallVector<StringRef, 4> matches;
|
||||
auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$");
|
||||
if (!shapeRegex.match(tyData, &matches)) {
|
||||
getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'");
|
||||
return nullptr;
|
||||
}
|
||||
SmallVector<int64_t, 4> shape;
|
||||
// Iterate through the captures, skip the first one which is the full string.
|
||||
for (auto dimStr :
|
||||
llvm::make_range(std::next(matches.begin()), matches.end())) {
|
||||
if (dimStr.startswith(","))
|
||||
continue; // POSIX misses non-capturing groups.
|
||||
if (dimStr.empty())
|
||||
continue; // '*' makes it an optional group capture
|
||||
// Convert the capture to an integer
|
||||
unsigned long long dim;
|
||||
if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) {
|
||||
getContext()->emitError(loc, Twine("Couldn't parse dimension as integer, matched: ") + dimStr);
|
||||
return mlir::Type();
|
||||
}
|
||||
shape.push_back(dim);
|
||||
}
|
||||
// Finally we collected all the dimensions in the shape,
|
||||
// create the array type.
|
||||
return ToyArrayType::get(getContext(), shape);
|
||||
}
|
||||
```
|
||||
|
||||
And we also update our IR generation from the Toy AST to use our new type
|
||||
instead of an opaque one:
|
||||
|
||||
```c++
|
||||
template <typename T> mlir::Type getType(T shape) {
|
||||
SmallVector<int64_t, 8> shape64(shape.begin(), shape.end());
|
||||
return ToyArrayType::get(&context, shape64);
|
||||
}
|
||||
```
|
||||
|
||||
From now on, MLIR knows how to parse types that are wrapped in `!toy<...>` and
|
||||
these won't be opaque anymore. The first consequence is that bogus IR with
|
||||
respect to our type won't be loaded anymore:
|
||||
|
||||
```bash(.sh)
|
||||
$ echo 'func @foo() -> !toy<"bla">' | toyc -emit=mlir -x mlir -
|
||||
loc("<stdin>":1:21): error: Invalid Toy type 'bla', array expected
|
||||
$ echo 'func @foo() -> !toy<"array<>">' | toyc -emit=mlir -x mlir -
|
||||
loc("<stdin>":1:21): error: Invalid toy array shape '<>'
|
||||
$ echo 'func @foo() -> !toy<"array<1, >">' | toyc -emit=mlir -x mlir -
|
||||
loc("<stdin>":1:21): error: Invalid toy array shape '<1, >'
|
||||
$ echo 'func @foo() -> !toy<"array<1, 2, 3>">' | toyc -emit=mlir -x mlir -
|
||||
func @foo() -> !toy<"array<1, 3>">
|
||||
```
|
||||
|
||||
## Defining a C++ Class for an Operation
|
||||
|
||||
After defining our custom type, we will register all the operations for the Toy
|
||||
language. Let's walk through the creation of the `toy.generic_call` operation:
|
||||
Which corresponds to the following IR:
|
||||
|
||||
```MLIR(.mlir)
|
||||
%4 = "toy.generic_call"(%1, %3) {callee: "my_func"}
|
||||
: (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array">
|
||||
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64>
|
||||
attributes {toy.generic} {
|
||||
%0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
|
||||
%1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64>
|
||||
"toy.return"(%1) : (tensor<*xf64>) -> ()
|
||||
}
|
||||
```
|
||||
|
||||
This operation takes a variable number of operands, all of which are expected to
|
||||
be Toy arrays, and return a single result. An operation inherit from `mlir::Op`
|
||||
and add some optional *traits* to customize its behavior.
|
||||
This is a good example of a transformation that is trivial to match on the Toy
|
||||
IR but that would be quite hard for LLVM to figure. For example today clang
|
||||
can't optimize away the temporary array and the computation with the naive
|
||||
transpose expressed with these loops:
|
||||
|
||||
```c++
|
||||
class GenericCallOp
|
||||
: public mlir::Op<GenericCallOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::OneResult> {
|
||||
#define N 100
|
||||
#define M 100
|
||||
|
||||
public:
|
||||
/// MLIR will use this to register the operation with the parser/printer.
|
||||
static llvm::StringRef getOperationName() { return "toy.generic_call"; }
|
||||
void sink(void *);
|
||||
void double_transpose(int A[N][M]) {
|
||||
int B[M][N];
|
||||
for(int i = 0; i < N; ++i) {
|
||||
for(int j = 0; j < M; ++j) {
|
||||
B[j][i] = A[i][j];
|
||||
}
|
||||
}
|
||||
for(int i = 0; i < N; ++i) {
|
||||
for(int j = 0; j < M; ++j) {
|
||||
A[i][j] = B[j][i];
|
||||
}
|
||||
}
|
||||
sink(A);
|
||||
}
|
||||
```
|
||||
|
||||
/// Operations can add custom verification beyond the traits they define.
|
||||
/// We will ensure that all the operands are Toy arrays.
|
||||
bool verify();
|
||||
For a simple C++ approach to rewrite involving matching a tree-like pattern in
|
||||
the IR and replacing it with a different set of operations, we can plug into the
|
||||
MLIR `Canonicalizer` pass by implementing a `RewritePattern`:
|
||||
|
||||
/// Interface to the builder to allow:
|
||||
/// mlir::OpBuilder::create<GenericCallOp>(...)
|
||||
/// This method populate the `state` that MLIR use to create operations.
|
||||
/// The `toy.generic_call` operation accepts a callee name and a list of
|
||||
/// arguments for the call.
|
||||
static void build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||
llvm::StringRef callee,
|
||||
llvm::ArrayRef<mlir::Value *> arguments);
|
||||
```c++
|
||||
/// Fold transpose(transpose(x) -> transpose(x)
|
||||
struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
|
||||
/// We register this pattern to match every toy.transpose in the IR.
|
||||
/// The "benefit" is used by the framework to order the patterns and process
|
||||
/// them in order of profitability.
|
||||
SimplifyRedundantTranspose(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
|
||||
|
||||
/// Return the name of the callee by fetching it from the attribute.
|
||||
llvm::StringRef getCalleeName();
|
||||
/// This method is attempting to match a pattern and rewrite it. The rewriter
|
||||
/// argument is the orchestrator of the sequence of rewrites. It is expected
|
||||
/// to interact with it to perform any changes to the IR from here.
|
||||
mlir::PatternMatchResult
|
||||
matchAndRewrite(TransposeOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
// Look through the input of the current transpose.
|
||||
mlir::Value *transposeInput = op.getOperand();
|
||||
TransposeOp transposeInputOp =
|
||||
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
|
||||
// If the input is defined by another Transpose, bingo!
|
||||
if (!transposeInputOp)
|
||||
return matchFailure();
|
||||
|
||||
private:
|
||||
using Op::Op;
|
||||
// Use the rewriter to perform the replacement
|
||||
rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
```
|
||||
|
||||
and we register this operation in the `ToyDialect` constructor:
|
||||
The implementation of this rewriter is in `ToyCombine.cpp`. The
|
||||
[canonicalization pass](../../Canonicalization.md) applies transformations
|
||||
defined by operations in a greedy, iterative manner. To ensure that the
|
||||
canonicalization pass applies our new transform, we set
|
||||
[hasCanonicalizer = 1](../../OpDefinitions.md#hascanonicalizer) and register the
|
||||
pattern with the canonicalization framework.
|
||||
|
||||
```c++
|
||||
ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) {
|
||||
addOperations<GenericCallOp>();
|
||||
addTypes<ToyArrayType>();
|
||||
// Register our patterns for rewrite by the Canonicalization framework.
|
||||
void TransposeOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<SimplifyRedundantTranspose>(context);
|
||||
}
|
||||
```
|
||||
|
||||
After creating classes for each of our operations, our dialect is ready and we
|
||||
have now better invariants enforced in our IR, and nicer API to implement
|
||||
analyses and transformations in the [next chapter](Ch-4.md).
|
||||
|
||||
## Using TableGen
|
||||
|
||||
FIXME: complete
|
||||
|
||||
## Revisiting the Builder API
|
||||
|
||||
We can now update `MLIRGen.cpp`, previously our use of the builder was very
|
||||
generic and creating a call operation looked like:
|
||||
|
||||
```
|
||||
// Calls to user-defined function are mapped to a custom call that takes
|
||||
// the callee name as an attribute.
|
||||
mlir::OperationState result(&context, location, "toy.generic_call");
|
||||
result.types.push_back(getType(VarType{}));
|
||||
result.operands = std::move(operands);
|
||||
for (auto &expr : call.getArgs()) {
|
||||
auto *arg = mlirGen(*expr);
|
||||
if (!arg)
|
||||
return nullptr;
|
||||
result.operands.push_back(arg);
|
||||
}
|
||||
auto calleeAttr = builder->getStringAttr(call.getCallee());
|
||||
result.attributes.push_back(builder->getNamedAttr("callee", calleeAttr));
|
||||
return builder->createOperation(result)->getResult(0);
|
||||
```
|
||||
|
||||
We replace it with this new version:
|
||||
We also need to update our main file, `toyc.cpp`, to add an optimization
|
||||
pipeline. In MLIR, the optimizations are ran through a `PassManager` in a
|
||||
similar way to LLVM:
|
||||
|
||||
```c++
|
||||
for (auto &expr : call.getArgs()) {
|
||||
auto *arg = mlirGen(*expr);
|
||||
if (!arg)
|
||||
return nullptr;
|
||||
operands.push_back(arg);
|
||||
}
|
||||
return builder->create<GenericCallOp>(location, call.getCallee(), operands)->getResult();
|
||||
mlir::PassManager pm(module.getContext());
|
||||
pm.addPass(mlir::createCanonicalizerPass());
|
||||
```
|
||||
|
||||
This interface offers better type safety, with some invariant enforced at the
|
||||
API level. For instance the `GenericCallOp` exposes now a `getResult()` method
|
||||
that does not take any argument, while before MLIR assumed the general cases and
|
||||
left open the possibility to have multiple returned values. The API was
|
||||
`getResult(int resultNum)`.
|
||||
Finally, we can try to run `toyc test/transpose_transpose.toy -emit=mlir -opt`
|
||||
and observe our pattern in action:
|
||||
|
||||
# Putting It All Together
|
||||
|
||||
After writing a class for each of our operation and implementing custom
|
||||
verifier, we try again the same example of invalid IR from the previous chapter:
|
||||
|
||||
```bash(.sh)
|
||||
$ cat test/invalid.mlir
|
||||
func @main() {
|
||||
%0 = "toy.print"() : () -> !toy.array<2, 3>
|
||||
```MLIR(.mlir)
|
||||
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64>
|
||||
attributes {toy.generic} {
|
||||
%0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
|
||||
%1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64>
|
||||
"toy.return"(%1) : (tensor<*xf64>) -> ()
|
||||
}
|
||||
$ toyc test/invalid.mlir -emit=mlir
|
||||
loc("test/invalid.mlir":2:8): error: 'toy.print' op requires a single operand
|
||||
```
|
||||
|
||||
This time the IR is correctly rejected by the verifier!
|
||||
As expected we now directly return the function argument, bypassing any
|
||||
transpose operation. However one of the transpose hasn't been eliminated. That
|
||||
is not ideal! What happened is that our pattern replaced the last transform with
|
||||
the function input and left behind the now dead transpose input. The
|
||||
Canonicalizer knows to cleanup dead operations, however MLIR conservatively
|
||||
assumes that operations may have side-effects. We can fix it by adding a new
|
||||
trait, `NoSideEffect`, to our `TransposeOp`:
|
||||
|
||||
In the [next chapter](Ch-4.md) we will leverage our new dialect to implement
|
||||
some high-level language-specific analyses and transformations for the Toy
|
||||
language.
|
||||
```TableGen(.td):
|
||||
def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {...}
|
||||
```
|
||||
|
||||
Let's retry now `toyc test/transpose_transpose.toy -emit=mlir -opt`:
|
||||
|
||||
```MLIR(.mlir)
|
||||
func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64>
|
||||
attributes {toy.generic} {
|
||||
"toy.return"(%arg0) : (tensor<*xf64>) -> ()
|
||||
}
|
||||
```
|
||||
|
||||
Perfect! No `transpose` operation is left, the code is optimal.
|
||||
|
||||
In the next section, we use DRR for pattern match optimizations associated with
|
||||
the Reshape op.
|
||||
|
||||
# Optimize Reshapes using DRR
|
||||
|
||||
Declarative, rule-based pattern-match and rewrite or DRR is an operation
|
||||
DAG-based declarative rewriter that provides a table-based syntax for
|
||||
pattern-match and rewrite rules:
|
||||
|
||||
```TableGen(.td):
|
||||
class Pattern<
|
||||
dag sourcePattern, list<dag> resultPatterns,
|
||||
list<dag> additionalConstraints = [],
|
||||
dag benefitsAdded = (addBenefit 0)>;
|
||||
```
|
||||
|
||||
A redundant reshape optimization similar to SimplifyRedundantTranspose can be
|
||||
expressed more simply using DRR as follows:
|
||||
|
||||
```TableGen(.td):
|
||||
// Reshape(Reshape(x)) = x
|
||||
def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
|
||||
(ReshapeOp $arg)>;
|
||||
```
|
||||
|
||||
The automatically generated C++ code corresponding to each of the DRR patterns
|
||||
can be found under $BUILD_DIR/projects/mlir/examples/toy/Ch3/ToyCombine.inc.
|
||||
|
||||
DRR also provides a method for adding argument constraints when the
|
||||
transformation is conditional on some properties of the arguments and results.
|
||||
An example is a transformation that eliminates reshapes when they are redundant,
|
||||
i.e. when the input and output shapes are identical.
|
||||
|
||||
```TableGen(.td):
|
||||
def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
|
||||
def RedundantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res $arg), (replaceWithValue $arg),
|
||||
[(TypesAreIdentical $res, $arg)]>;
|
||||
```
|
||||
|
||||
Some optimizations may require additional transformations on instruction
|
||||
arguments. This is achieved using NativeCodeCall, which allows for more complex
|
||||
transformations either by calling into a C++ helper function or by using inline
|
||||
C++. An example of such an optimization is FoldConstantReshape, where we
|
||||
optimize Reshape of a constant value by reshaping the constant in place and
|
||||
eliminating the reshape operation.
|
||||
|
||||
```TableGen(.td):
|
||||
def ReshapeConstant : NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
|
||||
def FoldConstantReshapeOptPattern : Pat<
|
||||
(ReshapeOp:$res (ConstantOp $arg)),
|
||||
(ConstantOp (ReshapeConstant $arg, $res))>;
|
||||
```
|
||||
|
||||
We demonstrate these reshape optimizations using the following
|
||||
trivialReshape.toy program:
|
||||
|
||||
```c++
|
||||
def main() {
|
||||
var a<2,1> = [1, 2];
|
||||
var b<2,1> = a;
|
||||
var c<2,1> = b;
|
||||
print(c);
|
||||
}
|
||||
```
|
||||
|
||||
```MLIR(.mlir)
|
||||
module {
|
||||
func @main() {
|
||||
%0 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>}
|
||||
: () -> tensor<2xf64>
|
||||
%1 = "toy.reshape"(%0) : (tensor<2xf64>) -> tensor<2x1xf64>
|
||||
%2 = "toy.reshape"(%1) : (tensor<2x1xf64>) -> tensor<2x1xf64>
|
||||
%3 = "toy.reshape"(%2) : (tensor<2x1xf64>) -> tensor<2x1xf64>
|
||||
"toy.print"(%3) : (tensor<2x1xf64>) -> ()
|
||||
"toy.return"() : () -> ()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
We can try to run `toyc test/trivialReshape.toy -emit=mlir -opt` and observe our
|
||||
pattern in action:
|
||||
|
||||
```MLIR(.mlir)
|
||||
module {
|
||||
func @main() {
|
||||
%0 = "toy.constant"() {value = dense<[[1.000000e+00], [2.000000e+00]]> \
|
||||
: tensor<2x1xf64>} : () -> tensor<2x1xf64>
|
||||
"toy.print"(%0) : (tensor<2x1xf64>) -> ()
|
||||
"toy.return"() : () -> ()
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
As expected, no reshape operations remain after canonicalization.
|
||||
|
||||
Further details on the declarative rewrite method can be found at
|
||||
[Table-driven Declarative Rewrite Rule (DRR)](../../DeclarativeRewrites.md).
|
||||
|
||||
@@ -13,20 +13,19 @@ def main() {
|
||||
print(d);
|
||||
}
|
||||
|
||||
# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array)
|
||||
# CHECK-NEXT: attributes {toy.generic = true} {
|
||||
# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array
|
||||
# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array
|
||||
# CHECK-NEXT: "toy.return"(%1) : (!toy.array) -> ()
|
||||
# CHECK-NEXT: }
|
||||
|
||||
# CHECK-LABEL: func @main() {
|
||||
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> !toy.array<2, 3>
|
||||
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3>
|
||||
# CHECK-NEXT: %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> !toy.array<6>
|
||||
# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3>
|
||||
# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
|
||||
# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array
|
||||
# CHECK-NEXT: "toy.print"(%5) : (!toy.array) -> ()
|
||||
# CHECK-NEXT: "toy.return"() : () -> ()
|
||||
# CHECK-LABEL: func @multiply_transpose(
|
||||
# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>)
|
||||
# CHECK-NEXT: attributes {toy.generic} {
|
||||
# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64>
|
||||
# CHECK-NEXT: [[VAL_3:%.*]] = "toy.mul"([[VAL_0]], [[VAL_2]]) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
|
||||
# CHECK-NEXT: "toy.return"([[VAL_3]]) : (tensor<*xf64>) -> ()
|
||||
|
||||
# CHECK-LABEL: func @main() {
|
||||
# CHECK-NEXT: [[VAL_4:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
|
||||
# CHECK-NEXT: [[VAL_5:%.*]] = "toy.reshape"([[VAL_4]]) : (tensor<2x3xf64>) -> tensor<2x3xf64>
|
||||
# CHECK-NEXT: [[VAL_6:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
|
||||
# CHECK-NEXT: [[VAL_7:%.*]] = "toy.reshape"([[VAL_6]]) : (tensor<6xf64>) -> tensor<2x3xf64>
|
||||
# CHECK-NEXT: [[VAL_8:%.*]] = "toy.generic_call"([[VAL_5]], [[VAL_7]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||
# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_7]], [[VAL_5]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
|
||||
# CHECK-NEXT: "toy.print"([[VAL_9]]) : (tensor<*xf64>) -> ()
|
||||
# CHECK-NEXT: "toy.return"() : () -> ()
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
// RUN: not toyc-ch3 %s -emit=mlir 2>&1
|
||||
|
||||
|
||||
// This IR is not "valid":
|
||||
// The following IR is not "valid":
|
||||
// - toy.print should not return a value.
|
||||
// - toy.print should take an argument.
|
||||
// - There should be a block terminator.
|
||||
// This all round-trip since this is opaque for MLIR.
|
||||
func @main() {
|
||||
%0 = "toy.print"() : () -> !toy.array<2, 3>
|
||||
%0 = "toy.print"() : () -> tensor<2x3xf64>
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ def main() {
|
||||
}
|
||||
|
||||
# CHECK-LABEL: func @main() {
|
||||
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<1xf64>} : () -> !toy.array<1>
|
||||
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2>
|
||||
# CHECK-NEXT: "toy.print"(%1) : (!toy.array<2, 2>) -> ()
|
||||
# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<f64>} : () -> tensor<f64>
|
||||
# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor<f64>) -> tensor<2x2xf64>
|
||||
# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> ()
|
||||
# CHECK-NEXT: "toy.return"() : () -> ()
|
||||
# CHECK-NEXT: }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user