From cd45b0c8d9f0eac7e76a892a2f2e993b340e90b5 Mon Sep 17 00:00:00 2001 From: Sana Damani Date: Tue, 15 Oct 2019 11:40:12 -0700 Subject: [PATCH] Update Chapter 3 to demonstrate pattern match and rewrite optimizations This is using Table-driven Declarative Rewrite Rules (DRR), the previous version of the tutorial only showed the C++ patterns. Closes tensorflow/mlir#187 PiperOrigin-RevId: 274852321 --- mlir/examples/toy/Ch3/CMakeLists.txt | 19 +- mlir/examples/toy/Ch3/include/toy/AST.h | 15 +- mlir/examples/toy/Ch3/include/toy/Dialect.h | 306 +------------ mlir/examples/toy/Ch3/include/toy/Lexer.h | 2 +- mlir/examples/toy/Ch3/include/toy/Ops.td | 247 ++++++++++ mlir/examples/toy/Ch3/mlir/Dialect.cpp | 151 ++++++ mlir/examples/toy/Ch3/mlir/MLIRGen.cpp | 408 ++++++++--------- mlir/examples/toy/Ch3/mlir/ToyCombine.cpp | 75 +++ mlir/examples/toy/Ch3/mlir/ToyCombine.td | 72 +++ mlir/examples/toy/Ch3/mlir/ToyDialect.cpp | 390 ---------------- mlir/examples/toy/Ch3/parser/AST.cpp | 2 +- mlir/examples/toy/Ch3/toyc.cpp | 24 +- mlir/g3doc/Tutorials/Toy/Ch-3.md | 484 +++++++++----------- mlir/test/Examples/Toy/Ch3/codegen.toy | 31 +- mlir/test/Examples/Toy/Ch3/invalid.mlir | 6 +- mlir/test/Examples/Toy/Ch3/scalar.toy | 6 +- 16 files changed, 1049 insertions(+), 1189 deletions(-) create mode 100644 mlir/examples/toy/Ch3/include/toy/Ops.td create mode 100644 mlir/examples/toy/Ch3/mlir/Dialect.cpp create mode 100644 mlir/examples/toy/Ch3/mlir/ToyCombine.cpp create mode 100644 mlir/examples/toy/Ch3/mlir/ToyCombine.td delete mode 100644 mlir/examples/toy/Ch3/mlir/ToyDialect.cpp diff --git a/mlir/examples/toy/Ch3/CMakeLists.txt b/mlir/examples/toy/Ch3/CMakeLists.txt index 060f3dd26ecf..d1b462bb2895 100644 --- a/mlir/examples/toy/Ch3/CMakeLists.txt +++ b/mlir/examples/toy/Ch3/CMakeLists.txt @@ -2,16 +2,33 @@ set(LLVM_LINK_COMPONENTS Support ) +set(LLVM_TARGET_DEFINITIONS include/toy/Ops.td) +mlir_tablegen(include/toy/Ops.h.inc -gen-op-decls) +mlir_tablegen(include/toy/Ops.cpp.inc 
-gen-op-defs) +add_public_tablegen_target(ToyCh3OpsIncGen) + +set(LLVM_TARGET_DEFINITIONS mlir/ToyCombine.td) +mlir_tablegen(ToyCombine.inc -gen-rewriters "-I${CMAKE_CURRENT_SOURCE_DIR}/include") +add_public_tablegen_target(ToyCh3CombineIncGen) + add_toy_chapter(toyc-ch3 toyc.cpp parser/AST.cpp mlir/MLIRGen.cpp - mlir/ToyDialect.cpp + mlir/Dialect.cpp + mlir/ToyCombine.cpp ) + +add_dependencies(toyc-ch3 ToyCh3OpsIncGen) +add_dependencies(toyc-ch3 ToyCh3CombineIncGen) include_directories(include/) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/) target_link_libraries(toyc-ch3 PRIVATE MLIRAnalysis MLIRIR MLIRParser + MLIRPass MLIRTransforms) + diff --git a/mlir/examples/toy/Ch3/include/toy/AST.h b/mlir/examples/toy/Ch3/include/toy/AST.h index 456a32309c40..2ad3392c11ac 100644 --- a/mlir/examples/toy/Ch3/include/toy/AST.h +++ b/mlir/examples/toy/Ch3/include/toy/AST.h @@ -33,10 +33,9 @@ namespace toy { -/// A variable +/// A variable type with shape information. struct VarType { - enum { TY_FLOAT, TY_INT } elt_ty; - std::vector shape; + std::vector shape; }; /// Base class for all expression nodes. @@ -50,9 +49,7 @@ public: Expr_Var, Expr_BinOp, Expr_Call, - Expr_Print, // builtin - Expr_If, - Expr_For, + Expr_Print, }; ExprAST(ExprASTKind kind, Location location) @@ -85,7 +82,7 @@ public: static bool classof(const ExprAST *C) { return C->getKind() == Expr_Num; } }; -/// +/// Expression class for a literal value. class LiteralExprAST : public ExprAST { std::vector> values; std::vector dims; @@ -116,7 +113,7 @@ public: static bool classof(const ExprAST *C) { return C->getKind() == Expr_Var; } }; -/// +/// Expression class for defining a variable. class VarDeclExprAST : public ExprAST { std::string name; VarType type; @@ -136,7 +133,7 @@ public: static bool classof(const ExprAST *C) { return C->getKind() == Expr_VarDecl; } }; -/// +/// Expression class for a return operator. 
class ReturnExprAST : public ExprAST { llvm::Optional> expr; diff --git a/mlir/examples/toy/Ch3/include/toy/Dialect.h b/mlir/examples/toy/Ch3/include/toy/Dialect.h index a256379661b8..91dd631d2ffb 100644 --- a/mlir/examples/toy/Ch3/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch3/include/toy/Dialect.h @@ -16,7 +16,7 @@ // ============================================================================= // // This file implements the IR Dialect for the Toy language. -// See g3doc/Tutorials/Toy/Ch-3.md for more information. +// See g3doc/Tutorials/Toy/Ch-2.md for more information. // //===----------------------------------------------------------------------===// @@ -25,311 +25,29 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Function.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/IR/Types.h" namespace mlir { -class Builder; -} - namespace toy { /// This is the definition of the Toy dialect. A dialect inherits from -/// mlir::Dialect and register custom operations and types (in its constructor). -/// It can also overriding general behavior of dialects exposed as virtual -/// method, for example regarding verification and parsing/printing. +/// mlir::Dialect and registers custom attributes, operations, and types (in its +/// constructor). It can also override some general behavior exposed via virtual +/// methods. class ToyDialect : public mlir::Dialect { public: explicit ToyDialect(mlir::MLIRContext *ctx); - /// Parse a type registered to this dialect. Overriding this method is - /// required for dialects that have custom types. - /// Technically this is only needed to be able to round-trip to textual IR. - mlir::Type parseType(llvm::StringRef tyData, - mlir::Location loc) const override; - - /// Print a type registered to this dialect. Overriding this method is - /// only required for dialects that have custom types. 
- /// Technically this is only needed to be able to round-trip to textual IR. - void printType(mlir::Type type, llvm::raw_ostream &os) const override; + /// Provide a utility accessor to the dialect namespace. This is used by + /// several utilities for casting between dialects. + static llvm::StringRef getDialectNamespace() { return "toy"; } }; -//////////////////////////////////////////////////////////////////////////////// -/////////////////////// Custom Types for the Dialect /////////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -namespace detail { -struct ToyArrayTypeStorage; -} - -/// LLVM-style RTTI: one entry per subclass to allow dyn_cast/isa. -enum ToyTypeKind { - // The enum starts at the range reserved for this dialect. - TOY_TYPE = mlir::Type::FIRST_TOY_TYPE, - TOY_ARRAY, -}; - -/// Type for Toy arrays. -/// In MLIR Types are reference to immutable and uniqued objects owned by the -/// MLIRContext. As such `ToyArrayType` only wraps a pointer to an uniqued -/// instance of `ToyArrayTypeStorage` (defined in our implementation file) and -/// provides the public facade API to interact with the type. -class ToyArrayType : public mlir::Type::TypeBase { -public: - using Base::Base; - - /// Returns the dimensions for this array, or and empty range for a generic - /// array. - llvm::ArrayRef getShape(); - - /// Predicate to test if this array is generic (shape haven't been inferred - /// yet). - bool isGeneric() { return getShape().empty(); } - - /// Return the rank of this array (0 if it is generic). - int getRank() { return getShape().size(); } - - /// Return the type of individual elements in the array. - mlir::Type getElementType(); - - /// Get the unique instance of this Type from the context. - /// A ToyArrayType is only defined by the shape of the array. - static ToyArrayType get(mlir::MLIRContext *context, - llvm::ArrayRef shape = {}); - - /// Support method to enable LLVM-style RTTI type casting. 
- static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; } -}; - -//////////////////////////////////////////////////////////////////////////////// -//////////////////// Custom Operations for the Dialect ///////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -/// Constant operation turns a literal into an SSA value. The data is attached -/// to the operation as an attribute. For example: -/// -/// %0 = "toy.constant"() -/// {value: dense, [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>} -/// : () -> !toy.array<2, 3> -/// -/// An operation inherits from `class Op` and specifies optional traits. Here we -/// indicate that `toy.constant` does not have any operands and returns a single -/// result. The traits provide some utilities methods for the operation, for -/// instance we will be able to use `getResult()`, but `getOperand()` won't be -/// available. -class ConstantOp : public mlir::Op { -public: - /// This is the name used by MLIR to match an operation to this class during - /// parsing. - static llvm::StringRef getOperationName() { return "toy.constant"; } - - /// The operation can have extra verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populates the `state` that MLIR uses to create operations. - /// The `toy.constant` operation does not have arguments but attaches a - /// constant array as an attribute and returns it as an SSA value. - static void build(mlir::Builder *builder, mlir::OperationState &state, - llvm::ArrayRef shape, - mlir::DenseElementsAttr value); - - /// Similar to the one above, but takes a single float and returns a - /// !toy.array<1>. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::FloatAttr value); - - /// Inherit constructor. 
- using Op::Op; -}; - -/// Generic calls represent calls to a user defined function that needs to -/// be specialized for the shape of its arguments. The callee name is attached -/// as a literal string as an attribute. The arguments list must match the -/// arguments expected by the callee. For example: -/// -/// %4 = "toy.generic_call"(%1, %3) {callee: "my_func"} -/// : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array"> -/// -/// This is only valid if a function named "my_func" exists and takes two -/// arguments. -class GenericCallOp - : public mlir::Op { -public: - /// MLIR will use this to register the operation with the parser/printer. - static llvm::StringRef getOperationName() { return "toy.generic_call"; } - - /// Operations can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to the builder to allow: - /// mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.generic_call` operation accepts a callee name and a list of - /// arguments for the call. - static void build(mlir::Builder *builder, mlir::OperationState &state, - llvm::StringRef callee, - llvm::ArrayRef arguments); - - /// Return the name of the callee. - llvm::StringRef getCalleeName(); - - /// Inherit constructor. - using Op::Op; -}; - -/// Return operations terminate blocks (and functions as well). They take a -/// single argument and the type must match the function return type. -class ReturnOp - : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.return"; } - - /// Operations can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.return` operation accepts an optional single array as an argument - /// and does not have any returned value. 
- static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value = nullptr); - - /// Return true if there is a returned value. - bool hasOperand() { return 0 != getNumOperands(); } - - /// Helper to return the optional operand. Caller must check if the operand - /// is present before calling this. - mlir::Value *getOperand() { return getOperation()->getOperand(0); } - - /// Inherit constructor. - using Op::Op; -}; - -/// The print builtin takes a single array argument and does not return any. -class PrintOp : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.print"; } - - /// Operations can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.print` operation accepts a single array as argument and does - /// not have any returned value. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value); - - /// Inherit constructor. - using Op::Op; -}; - -class TransposeOp : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.transpose"; } - - /// Operation can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.transpose` operation accepts a single array as argument and - /// returns the transposed array as its only result. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value); - - /// Inherit constructor. - using Op::Op; -}; - -/// Reshape operation is transforming its input array into a new array with the -/// same number of elements but different shapes. 
For example: -/// -/// %0 = "toy.reshape"(%arg1) : (!toy.array<10>) -> !toy.array<5, 2> -/// -class ReshapeOp : public mlir::Op { -public: - static llvm::StringRef getOperationName() { return "toy.reshape"; } - - /// Operation can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.reshape` operation accepts a single array as argument and - /// returns the array with the specified reshapedType as its only result. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value, ToyArrayType reshapedType); - - /// Inherit constructor. - using Op::Op; -}; - -/// Binary operation implementing a multiplication. For two-dimensional array -/// a matrix multiplication is implemented, while for one dimensional array a -/// dot product is performed. -class MulOp : public mlir::Op::Impl, - mlir::OpTrait::OneResult> { -public: - static llvm::StringRef getOperationName() { return "toy.mul"; } - - /// Operation can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.mul` operation accepts two operands as argument and returns - /// a single value. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs); - - /// Convenience accessor for LHS of the expression. - mlir::Value *getLHS() { return getOperand(0); } - - /// Convenience accessor for RHS of the expression. - mlir::Value *getRHS() { return getOperand(1); } - - /// Inherit constructor. - using Op::Op; -}; - -/// Element wise addition of two arrays. The shape must match. 
-class AddOp : public mlir::Op::Impl, - mlir::OpTrait::OneResult> { -public: - static llvm::StringRef getOperationName() { return "toy.add"; } - - /// Operation can add custom verification beyond the traits they define. - mlir::LogicalResult verify(); - - /// Interface to mlir::Builder::create(...) - /// This method populate the `state` that MLIR use to create operations. - /// The `toy.mul` operation accepts two operands as argument and returns - /// a single value. - static void build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs); - - /// Convenience accessor for LHS of the expression. - mlir::Value *getLHS() { return getOperand(0); } - - /// Convenience accessor for RHS of the expression. - mlir::Value *getRHS() { return getOperand(1); } - - /// Inherit constructor. - using Op::Op; -}; +/// Include the auto-generated header file containing the declarations of the +/// toy operations. +#define GET_OP_CLASSES +#include "toy/Ops.h.inc" } // end namespace toy +} // end namespace mlir #endif // MLIR_TUTORIAL_TOY_DIALECT_H_ diff --git a/mlir/examples/toy/Ch3/include/toy/Lexer.h b/mlir/examples/toy/Ch3/include/toy/Lexer.h index d73adb9706b7..21f92614912e 100644 --- a/mlir/examples/toy/Ch3/include/toy/Lexer.h +++ b/mlir/examples/toy/Ch3/include/toy/Lexer.h @@ -31,7 +31,7 @@ namespace toy { /// Structure definition a location in a file. struct Location { - std::shared_ptr file; ///< filename + std::shared_ptr file; ///< filename. int line; ///< line number. int col; ///< column number. }; diff --git a/mlir/examples/toy/Ch3/include/toy/Ops.td b/mlir/examples/toy/Ch3/include/toy/Ops.td new file mode 100644 index 000000000000..4d5c2f2cf1f6 --- /dev/null +++ b/mlir/examples/toy/Ch3/include/toy/Ops.td @@ -0,0 +1,247 @@ +//===- Ops.td - Toy dialect operation definitions ----------*- tablegen -*-===// +// +// Copyright 2019 The MLIR Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// Defines the operations of the Toy dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOY_OPS +#else +#define TOY_OPS + +#ifdef OP_BASE +#else +include "mlir/IR/OpBase.td" +#endif // OP_BASE + +// Provide a definition of the 'toy' dialect in the ODS framework so that we +// can define our operations. +def Toy_Dialect : Dialect { + let name = "toy"; + let cppNamespace = "toy"; +} + +// Base class for toy dialect operations. This operation inherits from the base +// `Op` class in OpBase.td, and provides: +// * The parent dialect of the operation. +// * The mnemonic for the operation, or the name without the dialect prefix. +// * A list of traits for the operation. +class Toy_Op traits = []> : + Op; + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +// We define a toy operation by inherting from our base 'Toy_Op' class above. +// Here we provide the mnemonic and a list of traits for the operation. The +// constant operation is marked as 'NoSideEffect' as it is a pure operation +// and may be removed if dead. +def ConstantOp : Toy_Op<"constant", [NoSideEffect]> { + // Provide a summary and description for this operation. 
This can be used to + // auto-generate documenatation of the operations within our dialect. + let summary = "constant"; + let description = [{ + Constant operation turns a literal into an SSA value. The data is attached + to the operation as an attribute. For example: + + ```mlir + %0 = "toy.constant"() + { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> } + : () -> tensor<2x3xf64> + ``` + }]; + + // The constant operation takes an attribute as the only input. + let arguments = (ins F64ElementsAttr:$value); + + // The constant operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the constant operation. These method populates + // the `state` that MLIR uses to create operations, i.e. these are used when + // using `builder.create(...)`. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &result, " + "DenseElementsAttr value", [{ + build(builder, result, value.getType(), value); + }]>, + + // Build a constant with a given constant floating-point value. + OpBuilder<"Builder *builder, OperationState &result, double value", [{ + buildConstantOp(builder, result, value); + }]> + ]; + + // Invoke a static verify method to verify this constant operation. + let verifier = [{ return ::verify(*this); }]; +} + +def AddOp : Toy_Op<"add", [NoSideEffect]> { + let summary = "element-wise addition operation"; + let description = [{ + The "add" operation performs element-wise addition between two tensors. + The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building an AddOp with from the two input operands. 
+ let builders = [ + OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{ + buildAddOp(b, result, lhs, rhs); + }] + >]; +} + +def GenericCallOp : Toy_Op<"generic_call"> { + let summary = "generic call operation"; + let description = [{ + Generic calls represent calls to a user defined function that needs to + be specialized for the shape of its arguments. The callee name is attached + as a symbol reference via an attribute. The arguments list must match the + arguments expected by the callee. For example: + + ```mlir + %4 = "toy.generic_call"(%1, %3) {callee = @my_func} + : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> + ``` + + This is only valid if a function named "my_func" exists and takes two + arguments. + }]; + + // The generic call operation takes a symbol reference attribute as the + // callee, and inputs for the call. + let arguments = (ins SymbolRefAttr:$callee, Variadic:$inputs); + + // The generic call operation returns a single value of TensorType. + let results = (outs F64Tensor); + + // Add custom build methods for the generic call operation. + let builders = [ + // Build a constant with a given constant tensor value. + OpBuilder<"Builder *builder, OperationState &result, " + "StringRef callee, ArrayRef arguments", [{ + buildGenericCallOp(builder, result, callee, arguments); + }]> + ]; +} + +def MulOp : Toy_Op<"mul", [NoSideEffect]> { + let summary = "element-wise multiplication operation"; + let description = [{ + The "mul" operation performs element-wise multiplication between two + tensors. The shapes of the tensor operands are expected to match. + }]; + + let arguments = (ins F64Tensor:$lhs, F64Tensor:$rhs); + let results = (outs F64Tensor); + + // Allow building a MulOp with from the two input operands. 
+ let builders = [ + OpBuilder<"Builder *b, OperationState &result, Value *lhs, Value *rhs", [{ + buildMulOp(b, result, lhs, rhs); + }] + >]; +} + +def PrintOp : Toy_Op<"print"> { + let summary = "print operation"; + let description = [{ + The "print" builtin operation prints a given input tensor, and produces + no results. + }]; + + // The print operation takes an input tensor to print. + let arguments = (ins F64Tensor:$input); +} + +def ReshapeOp : Toy_Op<"reshape", [NoSideEffect]> { + let summary = "tensor reshape operation"; + let description = [{ + Reshape operation is transforming its input tensor into a new tensor with + the same number of elements but different shapes. For example: + + ```mlir + %0 = "toy.reshape"(%arg1) : (tensor<10xf64>) -> tensor<5x2xf64> + ``` + }]; + + let arguments = (ins F64Tensor:$input); + + // Enabled registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // We expect that the reshape operation returns a statically shaped tensor. + let results = (outs StaticShapeTensorOf<[F64]>); +} + +def ReturnOp : Toy_Op<"return", [Terminator, HasParent<"FuncOp">]> { + let summary = "return operation"; + let description = [{ + The "return" operation represents a return operation within a function. + The operation takes an optional tensor operand and produces no results. + The operand type must match the signature of the function that contains + the operation. For example: + + ```mlir + func @foo() -> tensor<2xf64> { + ... + toy.return %0 : tensor<2xf64> + } + ``` + }]; + + // The return operation takes an optional input operand to return. This + // value must match the return type of the enclosing function. + let arguments = (ins Variadic:$input); + + // Allow building a ReturnOp with no return operand. + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; + + // Provide extra utility definitions on the c++ operation class definition. 
+ let extraClassDeclaration = [{ + bool hasOperand() { return getNumOperands() != 0; } + }]; + + // Invoke a static verify method to verify this return operation. + let verifier = [{ return ::verify(*this); }]; +} + +def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> { + let summary = "transpose operation"; + + let arguments = (ins F64Tensor:$input); + let results = (outs F64Tensor); + + // Enabled registering canonicalization patterns with this operation. + let hasCanonicalizer = 1; + + // Allow building a TransposeOp with from the two input operands. + let builders = [ + OpBuilder<"Builder *b, OperationState &result, Value *input", [{ + buildTransposeOp(b, result, input); + }] + >]; +} + +#endif // TOY_OPS diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp new file mode 100644 index 000000000000..375533b880c8 --- /dev/null +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -0,0 +1,151 @@ +//===- Dialect.cpp - Toy IR Dialect registration in MLIR ------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements the dialect for the Toy IR: custom type parsing and +// operation verification. 
+// +//===----------------------------------------------------------------------===// + +#include "toy/Dialect.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; +using namespace mlir::toy; + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Toy Operations +//===----------------------------------------------------------------------===// + +/// Build a constant operation. +/// The builder is passed as an argument, so is the state that this method is +/// expected to fill in order to build the operation. +static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state, + double value) { + auto dataType = builder->getTensorType({}, builder->getF64Type()); + auto dataAttribute = DenseElementsAttr::get(dataType, value); + ConstantOp::build(builder, state, dataType, dataAttribute); +} + +/// Verifier for constant operation. +static mlir::LogicalResult verify(ConstantOp op) { + // If the return type of the constant is not an unranked tensor, the shape + // must match the shape of the attribute holding the data. 
+ auto resultType = op.getResult()->getType().cast(); + if (!resultType) + return success(); + + auto attrType = op.value().getType().cast(); + if (attrType.getRank() != resultType.getRank()) { + return op.emitOpError( + "return type must match the one of the attached value " + "attribute: ") + << attrType.getRank() << " != " << resultType.getRank(); + } + for (int dim = 0; dim < attrType.getRank(); ++dim) { + if (attrType.getShape()[dim] != resultType.getShape()[dim]) { + return op.emitOpError( + "return type shape mismatches its attribute at dimension ") + << dim << ": " << attrType.getShape()[dim] + << " != " << resultType.getShape()[dim]; + } + } + return mlir::success(); +} + +static void buildAddOp(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value *lhs, mlir::Value *rhs) { + state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +static void buildGenericCallOp(mlir::Builder *builder, + mlir::OperationState &state, StringRef callee, + ArrayRef arguments) { + // Generic call always returns an unranked Tensor initially. + state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addOperands(arguments); + state.addAttribute("callee", builder->getSymbolRefAttr(callee)); +} + +static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state, + mlir::Value *lhs, mlir::Value *rhs) { + state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addOperands({lhs, rhs}); +} + +static mlir::LogicalResult verify(ReturnOp op) { + // We know that the parent operation is a function, because of the 'HasParent' + // trait attached to the operation definition. + auto function = cast(op.getParentOp()); + + /// ReturnOps can only have a single optional operand. + if (op.getNumOperands() > 1) + return op.emitOpError() << "expects at most 1 return operand"; + + // The operand number and types must match the function signature. 
+ const auto &results = function.getType().getResults(); + if (op.getNumOperands() != results.size()) + return op.emitOpError() + << "does not return the same number of values (" + << op.getNumOperands() << ") as the enclosing function (" + << results.size() << ")"; + + // If the operation does not have an input, we are done. + if (!op.hasOperand()) + return mlir::success(); + + auto inputType = *op.operand_type_begin(); + auto resultType = results.front(); + + // Check that the result type of the function matches the operand type. + if (inputType == resultType || inputType.isa() || + resultType.isa()) + return mlir::success(); + + return op.emitError() << "type of return operand (" + << *op.operand_type_begin() + << ") doesn't match function result type (" + << results.front() << ")"; +} + +static void buildTransposeOp(mlir::Builder *builder, + mlir::OperationState &state, mlir::Value *value) { + state.addTypes(builder->getTensorType(builder->getF64Type())); + state.addOperands(value); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "toy/Ops.cpp.inc" diff --git a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp index 9bd3fa68d11e..5f12d0a8798a 100644 --- a/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch3/mlir/MLIRGen.cpp @@ -25,30 +25,30 @@ #include "toy/Dialect.h" #include "mlir/Analysis/Verifier.h" -#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" -#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Types.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/raw_ostream.h" #include +using namespace 
mlir::toy; using namespace toy; + +using llvm::ArrayRef; using llvm::cast; using llvm::dyn_cast; using llvm::isa; +using llvm::makeArrayRef; using llvm::ScopedHashTableScope; using llvm::SmallVector; using llvm::StringRef; using llvm::Twine; -using std::make_unique; namespace { @@ -57,56 +57,43 @@ namespace { /// This will emit operations that are specific to the Toy language, preserving /// the semantics of the language and (hopefully) allow to perform accurate /// analysis and transformation based on these high level semantics. -/// -/// At this point we take advantage of the "raw" MLIR APIs to create operations -/// that haven't been registered in any way with MLIR. These operations are -/// unknown to MLIR, custom passes could operate by string-matching the name of -/// these operations, but no other type checking or semantic is associated with -/// them natively by MLIR. class MLIRGenImpl { public: - MLIRGenImpl(mlir::MLIRContext &context) : context(context) {} + MLIRGenImpl(mlir::MLIRContext &context) : builder(&context) {} /// Public API: convert the AST for a Toy module (source file) to an MLIR - /// Module. - mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) { + /// Module operation. + mlir::ModuleOp mlirGen(ModuleAST &moduleAST) { // We create an empty MLIR module and codegen functions one at a time and // add them to the module. - theModule = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + theModule = mlir::ModuleOp::create(builder.getUnknownLoc()); for (FunctionAST &F : moduleAST) { auto func = mlirGen(F); if (!func) return nullptr; - theModule->push_back(func); + theModule.push_back(func); } - // FIXME: (in the next chapter...) without registering a dialect in MLIR, - // this won't do much, but it should at least check some structural - // properties. 
- if (failed(mlir::verify(*theModule))) { - emitError(mlir::UnknownLoc::get(&context), "module verification error"); + // Verify the module after we have finished constructing it, this will check + // the structural properties of the IR and invoke any specific verifiers we + // have on the Toy operations. + if (failed(mlir::verify(theModule))) { + theModule.emitError("module verification error"); return nullptr; } - return std::move(theModule); + return theModule; } private: - /// In MLIR (like in LLVM) a "context" object holds the memory allocation and - /// the ownership of many internal structure of the IR and provide a level - /// of "uniquing" across multiple modules (types for instance). - mlir::MLIRContext &context; + /// A "module" matches a Toy source file: containing a list of functions. + mlir::ModuleOp theModule; - /// A "module" matches a source file: it contains a list of functions. - mlir::OwningModuleRef theModule; - - /// The builder is a helper class to create IR inside a function. It is - /// re-initialized every time we enter a function and kept around as a - /// convenience for emitting individual operations. - /// The builder is stateful, in particular it keeeps an "insertion point": - /// this is where the next operations will be introduced. - std::unique_ptr builder; + /// The builder is a helper class to create IR inside a function. The builder + /// is stateful, in particular it keeeps an "insertion point": this is where + /// the next operations will be introduced. + mlir::OpBuilder builder; /// The symbol table maps a variable name to a value in the current scope. /// Entering a function creates a new scope, and the function arguments are @@ -116,37 +103,35 @@ private: /// Helper conversion for a Toy AST location to an MLIR location. 
mlir::Location loc(Location loc) { - return mlir::FileLineColLoc::get(mlir::Identifier::get(*loc.file, &context), - loc.line, loc.col, &context); + return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line, + loc.col); } - /// Declare a variable in the current scope, return true if the variable + /// Declare a variable in the current scope, return success if the variable /// wasn't declared yet. - bool declare(llvm::StringRef var, mlir::Value *value) { - if (symbolTable.count(var)) { - return false; - } + mlir::LogicalResult declare(llvm::StringRef var, mlir::Value *value) { + if (symbolTable.count(var)) + return mlir::failure(); symbolTable.insert(var, value); - return true; + return mlir::success(); } /// Create the prototype for an MLIR function with as many arguments as the /// provided Toy AST prototype. mlir::FuncOp mlirGen(PrototypeAST &proto) { + auto location = loc(proto.loc()); + // This is a generic function, the return type will be inferred later. - llvm::SmallVector ret_types; - // Arguments type is uniformly a generic array. + // Arguments type are uniformly unranked tensors. llvm::SmallVector arg_types(proto.getArgs().size(), getType(VarType{})); - auto func_type = mlir::FunctionType::get(arg_types, ret_types, &context); - auto function = mlir::FuncOp::create(loc(proto.loc()), proto.getName(), - func_type, /* attrs = */ {}); + auto func_type = builder.getFunctionType(arg_types, llvm::None); + auto function = mlir::FuncOp::create(location, proto.getName(), func_type); // Mark the function as generic: it'll require type specialization for every // call site. if (function.getNumArguments()) - function.setAttr("toy.generic", mlir::BoolAttr::get(true, &context)); - + function.setAttr("toy.generic", builder.getUnitAttr()); return function; } @@ -165,18 +150,22 @@ private: // argument list as the function itself. 
auto &entryBlock = *function.addEntryBlock(); auto &protoArgs = funcAST.getProto()->getArgs(); + // Declare all the function arguments in the symbol table. for (const auto &name_value : llvm::zip(protoArgs, entryBlock.getArguments())) { - declare(std::get<0>(name_value)->getName(), std::get<1>(name_value)); + if (failed(declare(std::get<0>(name_value)->getName(), + std::get<1>(name_value)))) + return nullptr; } - // Create a builder for the function, it will be used throughout the codegen - // to create operations in this function. - builder = std::make_unique(function.getBody()); + // Set the insertion point in the builder to the beginning of the function + // body, it will be used throughout the codegen to create operations in this + // function. + builder.setInsertionPointToStart(&entryBlock); // Emit the body of the function. - if (!mlirGen(*funcAST.getBody())) { + if (mlir::failed(mlirGen(*funcAST.getBody()))) { function.erase(); return nullptr; } @@ -184,10 +173,16 @@ private: // Implicitly return void if no return statement was emitted. // FIXME: we may fix the parser instead to always return the last expression // (this would possibly help the REPL case later) - if (function.getBlocks().back().back().getName().getStringRef() != - "toy.return") { - ReturnExprAST fakeRet(funcAST.getProto()->loc(), llvm::None); - mlirGen(fakeRet); + ReturnOp returnOp; + if (!entryBlock.empty()) + returnOp = dyn_cast(entryBlock.back()); + if (!returnOp) { + builder.create(loc(funcAST.getProto()->loc())); + } else if (returnOp.hasOperand()) { + // Otherwise, if this return operation has an operand then add a result to + // the function. + function.setType(builder.getFunctionType(function.getType().getInputs(), + getType(VarType{}))); } return function; @@ -206,11 +201,11 @@ private: // and the result value is returned. If an error occurs we get a nullptr // and propagate. 
// - mlir::Value *L = mlirGen(*binop.getLHS()); - if (!L) + mlir::Value *lhs = mlirGen(*binop.getLHS()); + if (!lhs) return nullptr; - mlir::Value *R = mlirGen(*binop.getRHS()); - if (!R) + mlir::Value *rhs = mlirGen(*binop.getRHS()); + if (!rhs) return nullptr; auto location = loc(binop.loc()); @@ -218,123 +213,113 @@ private: // support '+' and '*'. switch (binop.getOp()) { case '+': - return builder->create(location, L, R).getResult(); - break; + return builder.create(location, lhs, rhs); case '*': - return builder->create(location, L, R).getResult(); - default: - emitError(loc(binop.loc()), "error: invalid binary operator '") - << binop.getOp() << "'"; - return nullptr; + return builder.create(location, lhs, rhs); } + + emitError(location, "invalid binary operator '") << binop.getOp() << "'"; + return nullptr; } - // This is a reference to a variable in an expression. The variable is - // expected to have been declared and so should have a value in the symbol - // table, otherwise emit an error and return nullptr. + /// This is a reference to a variable in an expression. The variable is + /// expected to have been declared and so should have a value in the symbol + /// table, otherwise emit an error and return nullptr. mlir::Value *mlirGen(VariableExprAST &expr) { - if (symbolTable.count(expr.getName())) - return symbolTable.lookup(expr.getName()); + if (auto *variable = symbolTable.lookup(expr.getName())) + return variable; + emitError(loc(expr.loc()), "error: unknown variable '") << expr.getName() << "'"; return nullptr; } - // Emit a return operation, return true on success. - bool mlirGen(ReturnExprAST &ret) { + /// Emit a return operation. This will return failure if any generation fails. + mlir::LogicalResult mlirGen(ReturnExprAST &ret) { auto location = loc(ret.loc()); - // `return` takes an optional expression, we need to account for it here. 
- if (!ret.getExpr().hasValue()) { - builder->create(location); - return true; + + // 'return' takes an optional expression, handle that case here. + mlir::Value *expr = nullptr; + if (ret.getExpr().hasValue()) { + if (!(expr = mlirGen(*ret.getExpr().getValue()))) + return mlir::failure(); } - auto *expr = mlirGen(*ret.getExpr().getValue()); - if (!expr) - return false; - builder->create(location, expr); - return true; + + // Otherwise, this return operation has zero operands. + builder.create(location, expr ? makeArrayRef(expr) + : ArrayRef()); + return mlir::success(); } - // Emit a literal/constant array. It will be emitted as a flattened array of - // data in an Attribute attached to a `toy.constant` operation. - // See documentation on [Attributes](LangRef.md#attributes) for more details. - // Here is an excerpt: - // - // Attributes are the mechanism for specifying constant data in MLIR in - // places where a variable is never allowed [...]. They consist of a name - // and a [concrete attribute value](#attribute-values). It is possible to - // attach attributes to operations, functions, and function arguments. The - // set of expected attributes, their structure, and their interpretation - // are all contextually dependent on what they are attached to. - // - // Example, the source level statement: - // var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; - // will be converted to: - // %0 = "toy.constant"() {value: dense, - // [[1.000000e+00, 2.000000e+00, 3.000000e+00], - // [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> memref<2x3xf64> - // + /// Emit a literal/constant array. It will be emitted as a flattened array of + /// data in an Attribute attached to a `toy.constant` operation. + /// See documentation on [Attributes](LangRef.md#attributes) for more details. + /// Here is an excerpt: + /// + /// Attributes are the mechanism for specifying constant data in MLIR in + /// places where a variable is never allowed [...]. 
They consist of a name + /// and a concrete attribute value. The set of expected attributes, their + /// structure, and their interpretation are all contextually dependent on + /// what they are attached to. + /// + /// Example, the source level statement: + /// var a<2, 3> = [[1, 2, 3], [4, 5, 6]]; + /// will be converted to: + /// %0 = "toy.constant"() {value: dense, + /// [[1.000000e+00, 2.000000e+00, 3.000000e+00], + /// [4.000000e+00, 5.000000e+00, 6.000000e+00]]>} : () -> tensor<2x3xf64> + /// mlir::Value *mlirGen(LiteralExprAST &lit) { - auto location = loc(lit.loc()); - // The attribute is a vector with an attribute per element (number) in the - // array, see `collectData()` below for more details. - std::vector data; + auto type = getType(lit.getDims()); + + // The attribute is a vector with a floating point value per element + // (number) in the array, see `collectData()` below for more details. + std::vector data; data.reserve(std::accumulate(lit.getDims().begin(), lit.getDims().end(), 1, std::multiplies())); collectData(lit, data); - // FIXME: using a tensor type is a HACK here. - // Can we do differently without registering a dialect? Using a string blob? - mlir::Type elementType = mlir::FloatType::getF64(&context); - auto dataType = builder->getTensorType(lit.getDims(), elementType); + // The type of this attribute is tensor of 64-bit floating-point with the + // shape of the literal. + mlir::Type elementType = builder.getF64Type(); + auto dataType = builder.getTensorType(lit.getDims(), elementType); - // This is the actual attribute that actually hold the list of values for - // this array literal. - auto dataAttribute = builder->getDenseElementsAttr(dataType, data) - .cast(); + // This is the actual attribute that holds the list of values for this + // tensor literal. + auto dataAttribute = + mlir::DenseElementsAttr::get(dataType, llvm::makeArrayRef(data)); - // Build the MLIR op `toy.constant`, only boilerplate below. 
- return builder->create(location, lit.getDims(), dataAttribute) - .getResult(); + // Build the MLIR op `toy.constant`. This invokes the `ConstantOp::build` + // method. + return builder.create(loc(lit.loc()), type, dataAttribute); } - // Recursive helper function to accumulate the data that compose an array - // literal. It flattens the nested structure in the supplied vector. For - // example with this array: - // [[1, 2], [3, 4]] - // we will generate: - // [ 1, 2, 3, 4 ] - // Individual numbers are wrapped in a light wrapper `mlir::FloatAttr`. - // Attributes are the way MLIR attaches constant to operations and functions. - void collectData(ExprAST &expr, std::vector &data) { + /// Recursive helper function to accumulate the data that compose an array + /// literal. It flattens the nested structure in the supplied vector. For + /// example with this array: + /// [[1, 2], [3, 4]] + /// we will generate: + /// [ 1, 2, 3, 4 ] + /// Individual numbers are represented as doubles. + /// Attributes are the way MLIR attaches constant to operations. + void collectData(ExprAST &expr, std::vector &data) { if (auto *lit = dyn_cast(&expr)) { for (auto &value : lit->getValues()) collectData(*value, data); return; } + assert(isa(expr) && "expected literal or number expr"); - mlir::Type elementType = mlir::FloatType::getF64(&context); - auto attr = mlir::FloatAttr::getChecked( - elementType, cast(expr).getValue(), loc(expr.loc())); - data.push_back(attr); + data.push_back(cast(expr).getValue()); } - // Emit a call expression. It emits specific operations for the `transpose` - // builtin. Other identifiers are assumed to be user-defined functions. + /// Emit a call expression. It emits specific operations for the `transpose` + /// builtin. Other identifiers are assumed to be user-defined functions. 
mlir::Value *mlirGen(CallExprAST &call) { + llvm::StringRef callee = call.getCallee(); auto location = loc(call.loc()); - std::string callee = call.getCallee(); - if (callee == "transpose") { - if (call.getArgs().size() != 1) { - emitError(location, "MLIR codegen encountered an error: toy.transpose " - "does not accept multiple arguments"); - return nullptr; - } - mlir::Value *arg = mlirGen(*call.getArgs()[0]); - return builder->create(location, arg).getResult(); - } - // Codegen the operands first + // Codegen the operands first. SmallVector operands; for (auto &expr : call.getArgs()) { auto *arg = mlirGen(*expr); @@ -342,34 +327,41 @@ private: return nullptr; operands.push_back(arg); } - // Calls to user-defined function are mapped to a custom call that takes - // the callee name as an attribute. - return builder->create(location, call.getCallee(), operands) - .getResult(); + + // Builting calls have their custom operation, meaning this is a + // straightforward emission. + if (callee == "transpose") { + if (call.getArgs().size() != 1) { + emitError(location, "MLIR codegen encountered an error: toy.transpose " + "does not accept multiple arguments"); + return nullptr; + } + return builder.create(location, operands[0]); + } + + // Otherwise this is a call to a user-defined function. Calls to ser-defined + // functions are mapped to a custom call that takes the callee name as an + // attribute. + return builder.create(location, callee, operands); } - // Emit a call expression. It emits specific operations for two builtins: - // transpose(x) and print(x). Other identifiers are assumed to be user-defined - // functions. Return false on failure. - bool mlirGen(PrintExprAST &call) { + /// Emit a print expression. It emits specific operations for two builtins: + /// transpose(x) and print(x). 
+ mlir::LogicalResult mlirGen(PrintExprAST &call) { auto *arg = mlirGen(*call.getArg()); if (!arg) - return false; - auto location = loc(call.loc()); - builder->create(location, arg); - return true; + return mlir::failure(); + + builder.create(loc(call.loc()), arg); + return mlir::success(); } - // Emit a constant for a single number (FIXME: semantic? broadcast?) + /// Emit a constant for a single number (FIXME: semantic? broadcast?) mlir::Value *mlirGen(NumberExprAST &num) { - auto location = loc(num.loc()); - mlir::Type elementType = mlir::FloatType::getF64(&context); - auto attr = mlir::FloatAttr::getChecked(elementType, num.getValue(), - loc(num.loc())); - return builder->create(location, attr).getResult(); + return builder.create(loc(num.loc()), num.getValue()); } - // Dispatch codegen for the right expression subclass using RTTI. + /// Dispatch codegen for the right expression subclass using RTTI. mlir::Value *mlirGen(ExprAST &expr) { switch (expr.getKind()) { case toy::ExprAST::Expr_BinOp: @@ -390,77 +382,75 @@ private: } } - // Handle a variable declaration, we'll codegen the expression that forms the - // initializer and record the value in the symbol table before returning it. - // Future expressions will be able to reference this variable through symbol - // table lookup. + /// Handle a variable declaration, we'll codegen the expression that forms the + /// initializer and record the value in the symbol table before returning it. + /// Future expressions will be able to reference this variable through symbol + /// table lookup. mlir::Value *mlirGen(VarDeclExprAST &vardecl) { - mlir::Value *value = nullptr; - auto location = loc(vardecl.loc()); - if (auto init = vardecl.getInitVal()) { - value = mlirGen(*init); - if (!value) - return nullptr; - // We have the initializer value, but in case the variable was declared - // with specific shape, we emit a "reshape" operation. It will get - // optimized out later as needed. 
- if (!vardecl.getType().shape.empty()) { - value = builder - ->create( - location, value, - getType(vardecl.getType()).cast()) - .getResult(); - } - } else { + auto init = vardecl.getInitVal(); + if (!init) { emitError(loc(vardecl.loc()), "missing initializer in variable declaration"); return nullptr; } - // Register the value in the symbol table - declare(vardecl.getName(), value); + + mlir::Value *value = mlirGen(*init); + if (!value) + return nullptr; + + // We have the initializer value, but in case the variable was declared + // with specific shape, we emit a "reshape" operation. It will get + // optimized out later as needed. + if (!vardecl.getType().shape.empty()) { + value = builder.create(loc(vardecl.loc()), + getType(vardecl.getType()), value); + } + + // Register the value in the symbol table. + if (failed(declare(vardecl.getName(), value))) + return nullptr; return value; } - /// Codegen a list of expression, return false if one of them hit an error. - bool mlirGen(ExprASTList &blockAST) { - ScopedHashTableScope var_scope(symbolTable); + /// Codegen a list of expression, return failure if one of them hit an error. + mlir::LogicalResult mlirGen(ExprASTList &blockAST) { + ScopedHashTableScope var_scope(symbolTable); for (auto &expr : blockAST) { // Specific handling for variable declarations, return statement, and // print. These can only appear in block list and not in nested // expressions. if (auto *vardecl = dyn_cast(expr.get())) { if (!mlirGen(*vardecl)) - return false; + return mlir::failure(); continue; } - if (auto *ret = dyn_cast(expr.get())) { - if (!mlirGen(*ret)) - return false; - return true; - } + if (auto *ret = dyn_cast(expr.get())) + return mlirGen(*ret); if (auto *print = dyn_cast(expr.get())) { - if (!mlirGen(*print)) - return false; + if (mlir::failed(mlirGen(*print))) + return mlir::success(); continue; } + // Generic expression dispatch codegen. 
if (!mlirGen(*expr)) - return false; + return mlir::failure(); } - return true; + return mlir::success(); } - /// Build a type from a list of shape dimensions. Types are `array` followed - /// by an optional dimension list, example: array<2, 2> - /// They are wrapped in a `toy` dialect (see next chapter) and get printed: - /// !toy.array<2, 2> - template mlir::Type getType(T shape) { - SmallVector shape64(shape.begin(), shape.end()); - return ToyArrayType::get(&context, shape64); + /// Build a tensor type from a list of shape dimensions. + mlir::Type getType(ArrayRef shape) { + // If the shape is empty, then this type is unranked. + if (shape.empty()) + return builder.getTensorType(builder.getF64Type()); + + // Otherwise, we use the given shape. + return builder.getTensorType(shape, builder.getF64Type()); } - /// Build an MLIR type from a Toy AST variable type - /// (forward to the generic getType(T) above). + /// Build an MLIR type from a Toy AST variable type (forward to the generic + /// getType above). mlir::Type getType(const VarType &type) { return getType(type.shape); } }; diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp new file mode 100644 index 000000000000..9eb152ccccb6 --- /dev/null +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp @@ -0,0 +1,75 @@ +//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a simple combiner for optimizing pattern in the Toy
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "toy/Dialect.h"
+#include <numeric>
+using namespace mlir;
+using namespace toy;
+
+namespace {
+/// Include the patterns defined in the Declarative Rewrite framework.
+#include "ToyCombine.inc"
+} // end anonymous namespace
+
+/// Fold transpose(transpose(x)) -> transpose(x)
+struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
+  /// We register this pattern to match every toy.transpose in the IR.
+  /// The "benefit" is used by the framework to order the patterns and process
+  /// them in order of profitability.
+  SimplifyRedundantTranspose(mlir::MLIRContext *context)
+      : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
+
+  /// This method attempts to match a pattern and rewrite it. The rewriter
+  /// argument is the orchestrator of the sequence of rewrites. It is expected
+  /// to interact with it to perform any changes to the IR from here.
+  mlir::PatternMatchResult
+  matchAndRewrite(TransposeOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    // Look through the input of the current transpose.
+    mlir::Value *transposeInput = op.getOperand();
+    TransposeOp transposeInputOp =
+        llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
+
+    // If the input is defined by another Transpose, bingo!
+    if (!transposeInputOp)
+      return matchFailure();
+
+    // Use the rewriter to perform the replacement
+    rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp});
+    return matchSuccess();
+  }
+};
+
+/// Register our patterns for rewrite by the Canonicalization framework.
+void TransposeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                              MLIRContext *context) {
+  results.insert<SimplifyRedundantTranspose>(context);
+}
+
+/// Register our patterns for rewrite by the Canonicalization framework.
+void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                            MLIRContext *context) {
+  results.insert<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
+                 FoldConstantReshapeOptPattern>(context);
+}
diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.td b/mlir/examples/toy/Ch3/mlir/ToyCombine.td
new file mode 100644
index 000000000000..1a7e464023ae
--- /dev/null
+++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.td
@@ -0,0 +1,72 @@
+//===- ToyCombine.td - Pattern Match Optimizations for Toy -*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines language-specific pattern match optimizations for Toy using
+// Declarative Rewrite Rules (DRR) specified using TableGen records.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOY_COMBINE
+#define TOY_COMBINE
+
+#ifndef OP_BASE
+include "toy/Ops.td"
+#endif // OP_BASE
+
+/* Pattern-Match and Rewrite using DRR:
+class Pattern<
+    dag sourcePattern, list<dag> resultPatterns,
+    list<dag> additionalConstraints = [],
+    dag benefitsAdded = (addBenefit 0)>;
+*/
+
+//===----------------------------------------------------------------------===//
+// Basic Pattern-Match and Rewrite
+//===----------------------------------------------------------------------===//
+
+// Reshape(Reshape(x)) = x
+def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
+                                   (ReshapeOp $arg)>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite using Native Code Call
+//===----------------------------------------------------------------------===//
+
+// Native Code Calls may be used for more complex transformations using inline
+// C++ and C++ helper functions.
+
+// Reshape(Constant(x)) = x'
+def ReshapeConstant :
+  NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
+def FoldConstantReshapeOptPattern : Pat<
+  (ReshapeOp:$res (ConstantOp $arg)),
+  (ConstantOp (ReshapeConstant $arg, $res))>;
+
+//===----------------------------------------------------------------------===//
+// Pattern-Match and Rewrite with Constraints
+//===----------------------------------------------------------------------===//
+
+// DRR allows for constraint checking when the transformation is conditional
+// on operand properties.
+
+// Reshape(x) = x, where input and output shapes are identical
+def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
+def RedundantReshapeOptPattern : Pat<
+  (ReshapeOp:$res $arg), (replaceWithValue $arg),
+  [(TypesAreIdentical $res, $arg)]>;
+
+#endif // TOY_COMBINE
diff --git a/mlir/examples/toy/Ch3/mlir/ToyDialect.cpp b/mlir/examples/toy/Ch3/mlir/ToyDialect.cpp
deleted file mode 100644
index ad1ad114e83c..000000000000
--- a/mlir/examples/toy/Ch3/mlir/ToyDialect.cpp
+++ /dev/null
@@ -1,390 +0,0 @@
-//===- ToyDialect.cpp - Toy IR Dialect registration in MLIR ---------------===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// This file implements the dialect for the Toy IR: custom type parsing and
-// operation verification.
-// -//===----------------------------------------------------------------------===// - -#include "toy/Dialect.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/Support/STLExtras.h" -#include "llvm/ADT/iterator_range.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/Regex.h" -#include "llvm/Support/raw_ostream.h" - -using llvm::ArrayRef; -using llvm::raw_ostream; -using llvm::raw_string_ostream; -using llvm::SmallVector; -using llvm::StringRef; -using llvm::Twine; - -namespace toy { -namespace detail { - -/// This class holds the implementation of the ToyArrayType. -/// It is intended to be uniqued based on its content and owned by the context. -struct ToyArrayTypeStorage : public mlir::TypeStorage { - /// This defines how we unique this type in the context: our key contains - /// only the shape, a more complex type would have multiple entries in the - /// tuple here. - /// The element of the tuples usually matches 1-1 the arguments from the - /// public `get()` method arguments from the facade. - using KeyTy = std::tuple>; - static unsigned hashKey(const KeyTy &key) { - return llvm::hash_combine(std::get<0>(key)); - } - /// When the key hash hits an existing type, we compare the shape themselves - /// to confirm we have the right type. - bool operator==(const KeyTy &key) const { return key == KeyTy(getShape()); } - - /// This is a factory method to create our type storage. It is only - /// invoked after looking up the type in the context using the key and not - /// finding it. - static ToyArrayTypeStorage *construct(mlir::TypeStorageAllocator &allocator, - const KeyTy &key) { - // Copy the shape array into the bumpptr allocator owned by the context. - ArrayRef shape = allocator.copyInto(std::get<0>(key)); - - // Allocate the instance for the ToyArrayTypeStorage itself - auto *storage = allocator.allocate(); - // Initialize the instance using placement new. 
- return new (storage) ToyArrayTypeStorage(shape); - } - - ArrayRef getShape() const { return shape; } - -private: - ArrayRef shape; - - /// Constructor is only invoked from the `construct()` method above. - ToyArrayTypeStorage(ArrayRef shape) : shape(shape) {} -}; - -} // namespace detail - -mlir::Type ToyArrayType::getElementType() { - return mlir::FloatType::getF64(getContext()); -} - -ToyArrayType ToyArrayType::get(mlir::MLIRContext *context, - ArrayRef shape) { - return Base::get(context, ToyTypeKind::TOY_ARRAY, shape); -} - -ArrayRef ToyArrayType::getShape() { return getImpl()->getShape(); } - -/// Dialect creation, the instance will be owned by the context. This is the -/// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { - addOperations(); - addTypes(); -} - -/// Parse a type registered to this dialect, we expect only Toy arrays. -mlir::Type ToyDialect::parseType(StringRef tyData, mlir::Location loc) const { - // Sanity check: we only support array or array<...> - if (!tyData.startswith("array")) { - emitError(loc, "invalid Toy type '" + tyData + "', array expected"); - return nullptr; - } - // Drop the "array" prefix from the type name, we expect either an empty - // string or just the shape. - tyData = tyData.drop_front(StringRef("array").size()); - // This is the generic array case without shape, early return it. - if (tyData.empty()) - return ToyArrayType::get(getContext()); - - // Use a regex to parse the shape (for efficient we should store this regex in - // the dialect itself). - SmallVector matches; - auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$"); - if (!shapeRegex.match(tyData, &matches)) { - emitError(loc, "invalid toy array shape '" + tyData + "'"); - return nullptr; - } - SmallVector shape; - // Iterate through the captures, skip the first one which is the full string. 
- for (auto dimStr : - llvm::make_range(std::next(matches.begin()), matches.end())) { - if (dimStr.startswith(",")) - continue; // POSIX misses non-capturing groups. - if (dimStr.empty()) - continue; // '*' makes it an optional group capture - // Convert the capture to an integer - unsigned long long dim; - if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) { - emitError(loc, "couldn't parse dimension as integer, matched: " + dimStr); - return mlir::Type(); - } - shape.push_back(dim); - } - // Finally we collected all the dimensions in the shape, - // create the array type. - return ToyArrayType::get(getContext(), shape); -} - -/// Print a Toy array type, for example `array<2, 3, 4>` -void ToyDialect::printType(mlir::Type type, raw_ostream &os) const { - auto arrayTy = type.dyn_cast(); - if (!arrayTy) { - os << "unknown toy type"; - return; - } - os << "array"; - if (!arrayTy.getShape().empty()) { - os << "<"; - mlir::interleaveComma(arrayTy.getShape(), os); - os << ">"; - } -} - -//////////////////////////////////////////////////////////////////////////////// -//////////////////// Custom Operations for the Dialect ///////////////////////// -//////////////////////////////////////////////////////////////////////////////// - -/// Helper to verify that the result of an operation is a Toy array type. -template static mlir::LogicalResult verifyToyReturnArray(T *op) { - if (!op->getResult()->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its argument, got " - << op->getResult()->getType(); - return op->emitOpError(os.str()); - } - return mlir::success(); -} - -/// Helper to verify that the two operands of a binary operation are Toy -/// arrays.. 
-template static mlir::LogicalResult verifyToyBinOperands(T *op) { - if (!op->getOperand(0)->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its LHS, got " - << op->getOperand(0)->getType(); - return op->emitOpError(os.str()); - } - if (!op->getOperand(1)->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its LHS, got " - << op->getOperand(0)->getType(); - return op->emitOpError(os.str()); - } - return mlir::success(); -} - -/// Build a constant operation. -/// The builder is passed as an argument, so is the state that this method is -/// expected to fill in order to build the operation. -void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, - ArrayRef shape, mlir::DenseElementsAttr value) { - state.types.push_back(ToyArrayType::get(builder->getContext(), shape)); - auto dataAttribute = builder->getNamedAttr("value", value); - state.attributes.push_back(dataAttribute); -} - -/// Build a constant operation. -/// The builder is passed as an argument, so is the state that this method is -/// expected to fill in order to build the operation. -void ConstantOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::FloatAttr value) { - // Broadcast and forward to the other build factory - mlir::Type elementType = mlir::FloatType::getF64(builder->getContext()); - auto dataType = builder->getTensorType({1}, elementType); - auto dataAttribute = builder->getDenseElementsAttr(dataType, {value}) - .cast(); - - ConstantOp::build(builder, state, {1}, dataAttribute); -} - -/// Verifier for constant operation. -mlir::LogicalResult ConstantOp::verify() { - // Ensure that the return type is a Toy array - if (failed(verifyToyReturnArray(this))) - return mlir::failure(); - - // We expect the constant itself to be stored as an attribute. 
- auto dataAttr = getAttr("value").dyn_cast(); - if (!dataAttr) { - return emitOpError( - "missing valid `value` DenseElementsAttribute on toy.constant()"); - } - auto attrType = dataAttr.getType().dyn_cast(); - if (!attrType) { - return emitOpError( - "missing valid `value` DenseElementsAttribute on toy.constant()"); - } - - // If the return type of the constant is not a generic array, the shape must - // match the shape of the attribute holding the data. - auto resultType = getResult()->getType().cast(); - if (!resultType.isGeneric()) { - if (attrType.getRank() != resultType.getRank()) { - return emitOpError("The rank of the toy.constant return type must match " - "the one of the attached value attribute: " + - Twine(attrType.getRank()) + - " != " + Twine(resultType.getRank())); - } - for (int dim = 0; dim < attrType.getRank(); ++dim) { - if (attrType.getShape()[dim] != resultType.getShape()[dim]) { - std::string msg; - raw_string_ostream os(msg); - return emitOpError( - "Shape mismatch between toy.constant return type and its " - "attribute at dimension " + - Twine(dim) + ": " + Twine(attrType.getShape()[dim]) + - " != " + Twine(resultType.getShape()[dim])); - } - } - } - return mlir::success(); -} - -void GenericCallOp::build(mlir::Builder *builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { - // Generic call always returns a generic ToyArray initially - state.types.push_back(ToyArrayType::get(builder->getContext())); - state.operands.assign(arguments.begin(), arguments.end()); - auto calleeAttr = builder->getStringAttr(callee); - state.attributes.push_back(builder->getNamedAttr("callee", calleeAttr)); -} - -mlir::LogicalResult GenericCallOp::verify() { - // Verify that every operand is a Toy Array - for (int opId = 0, num = getNumOperands(); opId < num; ++opId) { - if (!getOperand(opId)->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its " << opId << " operand, got " - 
<< getOperand(opId)->getType(); - return emitOpError(os.str()); - } - } - return mlir::success(); -} - -/// Return the name of the callee. -StringRef GenericCallOp::getCalleeName() { - return getAttr("callee").cast().getValue(); -} - -template static mlir::LogicalResult verifyToySingleOperand(T *op) { - if (!op->getOperand()->getType().template isa()) { - std::string msg; - raw_string_ostream os(msg); - os << "expects a Toy Array for its argument, got " - << op->getOperand()->getType(); - return op->emitOpError(os.str()); - } - return mlir::success(); -} - -void ReturnOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { - // Return does not return any value and has an optional single argument - if (value) - state.operands.push_back(value); -} - -mlir::LogicalResult ReturnOp::verify() { - if (getNumOperands() > 1) { - std::string msg; - raw_string_ostream os(msg); - os << "expects zero or one operand, got " << getNumOperands(); - return emitOpError(os.str()); - } - if (hasOperand() && failed(verifyToySingleOperand(this))) - return mlir::failure(); - return mlir::success(); -} - -void PrintOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { - // Print does not return any value and has a single argument - state.operands.push_back(value); -} - -mlir::LogicalResult PrintOp::verify() { - if (failed(verifyToySingleOperand(this))) - return mlir::failure(); - return mlir::success(); -} - -void TransposeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value) { - state.types.push_back(ToyArrayType::get(builder->getContext())); - state.operands.push_back(value); -} - -mlir::LogicalResult TransposeOp::verify() { - if (failed(verifyToySingleOperand(this))) - return mlir::failure(); - return mlir::success(); -} - -void ReshapeOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *value, ToyArrayType reshapedType) { - state.types.push_back(reshapedType); - 
state.operands.push_back(value); -} - -mlir::LogicalResult ReshapeOp::verify() { - if (failed(verifyToySingleOperand(this))) - return mlir::failure(); - auto retTy = getResult()->getType().dyn_cast(); - if (!retTy) - return emitOpError("toy.reshape is expected to produce a Toy array"); - if (retTy.isGeneric()) - return emitOpError("toy.reshape is expected to produce a shaped Toy array, " - "got a generic one."); - return mlir::success(); -} - -void AddOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { - state.types.push_back(ToyArrayType::get(builder->getContext())); - state.operands.push_back(lhs); - state.operands.push_back(rhs); -} - -mlir::LogicalResult AddOp::verify() { - if (failed(verifyToyBinOperands(this))) - return mlir::failure(); - return mlir::success(); -} - -void MulOp::build(mlir::Builder *builder, mlir::OperationState &state, - mlir::Value *lhs, mlir::Value *rhs) { - state.types.push_back(ToyArrayType::get(builder->getContext())); - state.operands.push_back(lhs); - state.operands.push_back(rhs); -} - -mlir::LogicalResult MulOp::verify() { - if (failed(verifyToyBinOperands(this))) - return mlir::failure(); - return mlir::success(); -} - -} // namespace toy diff --git a/mlir/examples/toy/Ch3/parser/AST.cpp b/mlir/examples/toy/Ch3/parser/AST.cpp index fde8b101e838..869f2ef2013d 100644 --- a/mlir/examples/toy/Ch3/parser/AST.cpp +++ b/mlir/examples/toy/Ch3/parser/AST.cpp @@ -125,7 +125,7 @@ void ASTDumper::dump(NumberExprAST *num) { llvm::errs() << num->getValue() << " " << loc(num) << "\n"; } -/// Helper to print recursively a literal. This handles nested array like: +/// Helper to print recurisvely a literal. 
This handles nested array like: /// [ [ 1, 2 ], [ 3, 4 ] ] /// We print out such array with the dimensions spelled out at every level: /// <2,2>[<2>[ 1, 2 ], <2>[ 3, 4 ] ] diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp index 1e7deb670c75..ad6cf6363143 100644 --- a/mlir/examples/toy/Ch3/toyc.cpp +++ b/mlir/examples/toy/Ch3/toyc.cpp @@ -22,12 +22,14 @@ #include "toy/Dialect.h" #include "toy/MLIRGen.h" #include "toy/Parser.h" -#include #include "mlir/Analysis/Verifier.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/Passes.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" @@ -61,6 +63,8 @@ static cl::opt emitAction( cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")), cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump"))); +static cl::opt EnableOpt("opt", cl::desc("Enable optimizations")); + /// Returns a Toy AST resulting from parsing the file or a nullptr on error. std::unique_ptr parseInputFile(llvm::StringRef filename) { llvm::ErrorOr> FileOrErr = @@ -75,9 +79,18 @@ std::unique_ptr parseInputFile(llvm::StringRef filename) { return parser.ParseModule(); } +mlir::LogicalResult optimize(mlir::ModuleOp module) { + mlir::PassManager pm(module.getContext()); + pm.addPass(mlir::createCanonicalizerPass()); + + // Apply any generic pass manager command line options and run the pipeline. 
+ applyPassManagerCLOptions(pm); + return pm.run(module); +} + int dumpMLIR() { // Register our Dialect with MLIR - mlir::registerDialect(); + mlir::registerDialect(); mlir::MLIRContext context; mlir::OwningModuleRef module; @@ -106,6 +119,12 @@ int dumpMLIR() { } if (!module) return 1; + if (EnableOpt) { + if (failed(optimize(*module))) { + llvm::errs() << "Module optimization failed\n"; + return 7; + } + } module->dump(); return 0; } @@ -125,6 +144,7 @@ int dumpAST() { } int main(int argc, char **argv) { + mlir::registerPassManagerCLOptions(); cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); switch (emitAction) { diff --git a/mlir/g3doc/Tutorials/Toy/Ch-3.md b/mlir/g3doc/Tutorials/Toy/Ch-3.md index 280f825f0b44..5305d15aadac 100644 --- a/mlir/g3doc/Tutorials/Toy/Ch-3.md +++ b/mlir/g3doc/Tutorials/Toy/Ch-3.md @@ -1,296 +1,262 @@ -# Chapter 3: Defining and Registering a Dialect in MLIR +# Chapter 3: High-level Language-Specific Analysis and Transformation -In the previous chapter, we saw how to emit a custom IR for Toy in MLIR using -opaque operations. In this chapter we will register our Dialect with MLIR to -start making the Toy IR more robust and friendly to use. +Creating a dialect that closely represents the semantics of an input language +enables analyses, transformations and optimizations in MLIR that require high +level language information and are generally performed on the language AST. For +example, `clang` has a fairly +[heavy mechanism](https://clang.llvm.org/doxygen/classclang_1_1TreeTransform.html) +for performing template instantiation in C++. -Dialects in MLIR allow for registering operations and types with an MLIRContext. -They also must reserve a "namespace" to avoid collision with other registered -dialects. These registered operations are no longer opaque to MLIR: for example -we can teach the MLIR verifier to enforce some invariants on the IR. +We divide compiler transformations into two categories: local and global. 
In +this chapter, we focus on how to leverage the Toy Dialect and its high-level +semantics to perform local pattern-match transformations that would be difficult +in LLVM. For this, we use MLIR's +[Generic DAG Rewriter](../../GenericDAGRewriter.md). -```c++ -/// This is the definition of the Toy dialect. A dialect inherits from -/// mlir::Dialect and registers custom operations and types (in its constructor). -/// It can also override general behavior of dialects exposed as virtual -/// methods, for example regarding verification and parsing/printing. -class ToyDialect : public mlir::Dialect { - public: - explicit ToyDialect(mlir::MLIRContext *ctx); +There are two methods that can be used to implement pattern-match +transformations: 1. Imperative, C++ pattern-match and rewrite 2. Declarative, +rule-based pattern-match and rewrite using +[Table-driven Declarative Rewrite Rule](../../DeclarativeRewrites.md) (DRR). +Note that the use of DRR requires that the operations be defined using ODS as +described in [Chapter 2](../Ch-2.md). - /// Parse a type registered to this dialect. Overriding this method is - /// required for dialects that have custom types. - /// Technically this is only needed to be able to round-trip to textual IR. - mlir::Type parseType(llvm::StringRef tyData, - mlir::Location loc) const override; +# Optimize Transpose using C++ style pattern-match and rewrite - /// Print a type registered to this dialect. Overriding this method is - /// only required for dialects that have custom types. - /// Technically this is only needed to be able to round-trip to textual IR. - void printType(mlir::Type type, llvm::raw_ostream &os) const override; -}; -``` +Let's start with a simple pattern and try to eliminate a sequence of two +transpose that cancel out: `transpose(transpose(X)) -> X`. 
Here is the +corresponding Toy example: -The dialect can now be registered in the global registry: - -```c++ - mlir::registerDialect(); -``` - -Any new `MLIRContext` created from now on will recognize the `toy` prefix when -parsing new types and invoke our `parseType` method. We will see later how to -enable custom operations, but first let's define a custom type to handle Toy -arrays. - -# Custom Type Handling - -As you may have noticed in the previous chapter, dialect specific types in MLIR -are serialized as strings. In the case of Toy, an example would be -`!toy.array<2, 3>`. MLIR will find the ToyDialect from the `!toy` prefix but it -is up to the dialect itself to translate the content of the string into a proper -type. - -First we need to define the class representing our type. In MLIR, types are -references to immutable and uniqued objects owned by the MLIRContext. As such, -our `ToyArrayType` will only be a wrapper around a pointer to an uniqued -instance of `ToyArrayTypeStorage` in the Context and provide the public facade -API to interact with the type. - -```c++ -class ToyArrayType : public mlir::Type::TypeBase { - public: - /// Returns the dimensions for this Toy array, or an empty range for a generic array. - llvm::ArrayRef getShape(); - - /// Predicate to test if this array is generic (shape haven't been inferred yet). - bool isGeneric() { return getShape().empty(); } - - /// Return the rank of this array (0 if it is generic) - int getRank() { return getShape().size(); } - - /// Get the unique instance of this Type from the context. - /// A ToyArrayType is only defined by the shape of the array. - static ToyArrayType get(mlir::MLIRContext *context, - llvm::ArrayRef shape = {}); - - /// Support method to enable LLVM-style RTTI type casting. 
- static bool kindof(unsigned kind) { return kind == ToyTypeKind::TOY_ARRAY; } -}; -``` - -Implementing `getShape()` for example is just about retrieving the pointer to -the uniqued instance and forwarding: - -```c++ -llvm::ArrayRef ToyArrayType::getShape() { - return getImpl()->getShape(); +```Toy(.toy) +def transpose_transpose(x) { + return transpose(transpose(x)); } ``` -The calls to `getImpl()` give access to the `ToyArrayTypeStorage` that holds the -information for this type. For details about how the storage of the type works, -we'll refer you to `Ch3/mlir/ToyDialect.cpp`. - -Finally, the Toy dialect can register the type with MLIR, and implement some -custom parsing for our types: - -```c++ -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) - // note the `toy` prefix that we reserve here. - : mlir::Dialect("toy", ctx) { - // Register our custom type with MLIR. - addTypes(); -} - -/// Parse a type registered to this dialect, we expect only Toy arrays. -mlir::Type ToyDialect::parseType(StringRef tyData, - mlir::Location loc) const { - // Sanity check: we only support array or array<...> - if (!tyData.startswith("array")) { - getContext()->emitError(loc, "Invalid Toy type '" + tyData + - "', array expected"); - return nullptr; - } - // Drop the "array" prefix from the type name, we expect either an empty - // string or just the shape. - tyData = tyData.drop_front(StringRef("array").size()); - // This is the generic array case without shape, early return it. - if (tyData.empty()) - return ToyArrayType::get(getContext()); - - // Use a regex to parse the shape (for efficient we should store this regex in - // the dialect itself). - SmallVector matches; - auto shapeRegex = llvm::Regex("^<([0-9]+)(, ([0-9]+))*>$"); - if (!shapeRegex.match(tyData, &matches)) { - getContext()->emitError(loc, "Invalid toy array shape '" + tyData + "'"); - return nullptr; - } - SmallVector shape; - // Iterate through the captures, skip the first one which is the full string. 
- for (auto dimStr : - llvm::make_range(std::next(matches.begin()), matches.end())) { - if (dimStr.startswith(",")) - continue; // POSIX misses non-capturing groups. - if (dimStr.empty()) - continue; // '*' makes it an optional group capture - // Convert the capture to an integer - unsigned long long dim; - if (getAsUnsignedInteger(dimStr, /* Radix = */ 10, dim)) { - getContext()->emitError(loc, Twine("Couldn't parse dimension as integer, matched: ") + dimStr); - return mlir::Type(); - } - shape.push_back(dim); - } - // Finally we collected all the dimensions in the shape, - // create the array type. - return ToyArrayType::get(getContext(), shape); -} -``` - -And we also update our IR generation from the Toy AST to use our new type -instead of an opaque one: - -```c++ -template mlir::Type getType(T shape) { - SmallVector shape64(shape.begin(), shape.end()); - return ToyArrayType::get(&context, shape64); -} -``` - -From now on, MLIR knows how to parse types that are wrapped in `!toy<...>` and -these won't be opaque anymore. The first consequence is that bogus IR with -respect to our type won't be loaded anymore: - -```bash(.sh) -$ echo 'func @foo() -> !toy<"bla">' | toyc -emit=mlir -x mlir - -loc("":1:21): error: Invalid Toy type 'bla', array expected -$ echo 'func @foo() -> !toy<"array<>">' | toyc -emit=mlir -x mlir - -loc("":1:21): error: Invalid toy array shape '<>' -$ echo 'func @foo() -> !toy<"array<1, >">' | toyc -emit=mlir -x mlir - -loc("":1:21): error: Invalid toy array shape '<1, >' -$ echo 'func @foo() -> !toy<"array<1, 2, 3>">' | toyc -emit=mlir -x mlir - -func @foo() -> !toy<"array<1, 3>"> -``` - -## Defining a C++ Class for an Operation - -After defining our custom type, we will register all the operations for the Toy -language. 
Let's walk through the creation of the `toy.generic_call` operation: +Which corresponds to the following IR: ```MLIR(.mlir) - %4 = "toy.generic_call"(%1, %3) {callee: "my_func"} - : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy<"array"> +func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64> +attributes {toy.generic} { + %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64> + %1 = "toy.transpose"(%0) : (tensor<*xf64>) -> tensor<*xf64> + "toy.return"(%1) : (tensor<*xf64>) -> () +} ``` -This operation takes a variable number of operands, all of which are expected to -be Toy arrays, and return a single result. An operation inherit from `mlir::Op` -and add some optional *traits* to customize its behavior. +This is a good example of a transformation that is trivial to match on the Toy +IR but that would be quite hard for LLVM to figure. For example today clang +can't optimize away the temporary array and the computation with the naive +transpose expressed with these loops: ```c++ -class GenericCallOp - : public mlir::Op { +#define N 100 +#define M 100 - public: - /// MLIR will use this to register the operation with the parser/printer. - static llvm::StringRef getOperationName() { return "toy.generic_call"; } +void sink(void *); +void double_transpose(int A[N][M]) { + int B[M][N]; + for(int i = 0; i < N; ++i) { + for(int j = 0; j < M; ++j) { + B[j][i] = A[i][j]; + } + } + for(int i = 0; i < N; ++i) { + for(int j = 0; j < M; ++j) { + A[i][j] = B[j][i]; + } + } + sink(A); +} +``` - /// Operations can add custom verification beyond the traits they define. - /// We will ensure that all the operands are Toy arrays. - bool verify(); +For a simple C++ approach to rewrite involving matching a tree-like pattern in +the IR and replacing it with a different set of operations, we can plug into the +MLIR `Canonicalizer` pass by implementing a `RewritePattern`: - /// Interface to the builder to allow: - /// mlir::OpBuilder::create(...) 
- /// This method populate the `state` that MLIR use to create operations. - /// The `toy.generic_call` operation accepts a callee name and a list of - /// arguments for the call. - static void build(mlir::OpBuilder &builder, mlir::OperationState &state, - llvm::StringRef callee, - llvm::ArrayRef arguments); +```c++ +/// Fold transpose(transpose(x) -> transpose(x) +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { + /// We register this pattern to match every toy.transpose in the IR. + /// The "benefit" is used by the framework to order the patterns and process + /// them in order of profitability. + SimplifyRedundantTranspose(mlir::MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/1) {} - /// Return the name of the callee by fetching it from the attribute. - llvm::StringRef getCalleeName(); + /// This method is attempting to match a pattern and rewrite it. The rewriter + /// argument is the orchestrator of the sequence of rewrites. It is expected + /// to interact with it to perform any changes to the IR from here. + mlir::PatternMatchResult + matchAndRewrite(TransposeOp op, + mlir::PatternRewriter &rewriter) const override { + // Look through the input of the current transpose. + mlir::Value *transposeInput = op.getOperand(); + TransposeOp transposeInputOp = + llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); + // If the input is defined by another Transpose, bingo! + if (!transposeInputOp) + return matchFailure(); - private: - using Op::Op; + // Use the rewriter to perform the replacement + rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); + return matchSuccess(); + } }; ``` -and we register this operation in the `ToyDialect` constructor: +The implementation of this rewriter is in `ToyCombine.cpp`. The +[canonicalization pass](../../Canonicalization.md) applies transformations +defined by operations in a greedy, iterative manner. 
To ensure that the +canonicalization pass applies our new transform, we set +[hasCanonicalizer = 1](../../OpDefinitions.md#hascanonicalizer) and register the +pattern with the canonicalization framework. ```c++ -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) : mlir::Dialect("toy", ctx) { - addOperations(); - addTypes(); +// Register our patterns for rewrite by the Canonicalization framework. +void TransposeOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); } ``` -After creating classes for each of our operations, our dialect is ready and we -have now better invariants enforced in our IR, and nicer API to implement -analyses and transformations in the [next chapter](Ch-4.md). - -## Using TableGen - -FIXME: complete - -## Revisiting the Builder API - -We can now update `MLIRGen.cpp`, previously our use of the builder was very -generic and creating a call operation looked like: - -``` - // Calls to user-defined function are mapped to a custom call that takes - // the callee name as an attribute. - mlir::OperationState result(&context, location, "toy.generic_call"); - result.types.push_back(getType(VarType{})); - result.operands = std::move(operands); - for (auto &expr : call.getArgs()) { - auto *arg = mlirGen(*expr); - if (!arg) - return nullptr; - result.operands.push_back(arg); - } - auto calleeAttr = builder->getStringAttr(call.getCallee()); - result.attributes.push_back(builder->getNamedAttr("callee", calleeAttr)); - return builder->createOperation(result)->getResult(0); -``` - -We replace it with this new version: +We also need to update our main file, `toyc.cpp`, to add an optimization +pipeline. 
In MLIR, the optimizations are run through a `PassManager` in a
+similar way to LLVM:
+
+```c++
+  mlir::PassManager pm(module.getContext());
+  pm.addPass(mlir::createCanonicalizerPass());
+```
+
+Finally, we can try to run `toyc test/transpose_transpose.toy -emit=mlir -opt`
+and observe our pattern in action:
+
+```MLIR(.mlir)
+func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64>
+attributes {toy.generic} {
+  %0 = "toy.transpose"(%arg0) : (tensor<*xf64>) -> tensor<*xf64>
+  "toy.return"(%arg0) : (tensor<*xf64>) -> ()
+}
+```
+
+As expected we now directly return the function argument, bypassing any
+transpose operation. However one of the transposes hasn't been eliminated. That
+is not ideal! What happened is that our pattern replaced the last transform with
+the function input and left behind the now dead transpose input.
The
+Canonicalizer knows to clean up dead operations, however MLIR conservatively
+assumes that operations may have side-effects. We can fix it by adding a new
+trait, `NoSideEffect`, to our `TransposeOp`:

+```TableGen(.td):
+def TransposeOp : Toy_Op<"transpose", [NoSideEffect]> {...}
+```
+
+Let's retry now `toyc test/transpose_transpose.toy -emit=mlir -opt`:
+
+```MLIR(.mlir)
+func @transpose_transpose(%arg0: tensor<*xf64>) -> tensor<*xf64>
+attributes {toy.generic} {
+  "toy.return"(%arg0) : (tensor<*xf64>) -> ()
+}
+```
+
+Perfect! No `transpose` operation is left, the code is optimal.
+
+In the next section, we use DRR for pattern match optimizations associated with
+the Reshape op.
+
+# Optimize Reshapes using DRR
+
+Declarative, rule-based pattern-match and rewrite or DRR is an operation
+DAG-based declarative rewriter that provides a table-based syntax for
+pattern-match and rewrite rules:
+
+```TableGen(.td):
+class Pattern<
+    dag sourcePattern, list<dag> resultPatterns,
+    list<dag> additionalConstraints = [],
+    dag benefitsAdded = (addBenefit 0)>;
+```
+
+A redundant reshape optimization similar to SimplifyRedundantTranspose can be
+expressed more simply using DRR as follows:
+
+```TableGen(.td):
+// Reshape(Reshape(x)) = x
+def ReshapeReshapeOptPattern : Pat<(ReshapeOp(ReshapeOp $arg)),
+                                   (ReshapeOp $arg)>;
+```
+
+The automatically generated C++ code corresponding to each of the DRR patterns
+can be found under $BUILD_DIR/projects/mlir/examples/toy/Ch3/ToyCombine.inc.
+
+DRR also provides a method for adding argument constraints when the
+transformation is conditional on some properties of the arguments and results.
+An example is a transformation that eliminates reshapes when they are redundant,
+i.e. when the input and output shapes are identical.
+
+```TableGen(.td):
+def TypesAreIdentical : Constraint<CPred<"$0->getType() == $1->getType()">>;
+def RedundantReshapeOptPattern : Pat<
+  (ReshapeOp:$res $arg), (replaceWithValue $arg),
+  [(TypesAreIdentical $res, $arg)]>;
+```
+
+Some optimizations may require additional transformations on instruction
+arguments. This is achieved using NativeCodeCall, which allows for more complex
+transformations either by calling into a C++ helper function or by using inline
+C++. An example of such an optimization is FoldConstantReshape, where we
+optimize Reshape of a constant value by reshaping the constant in place and
+eliminating the reshape operation.
+
+```TableGen(.td):
+def ReshapeConstant : NativeCodeCall<"$0.reshape(($1->getType()).cast<ShapedType>())">;
+def FoldConstantReshapeOptPattern : Pat<
+  (ReshapeOp:$res (ConstantOp $arg)),
+  (ConstantOp (ReshapeConstant $arg, $res))>;
+```
+
+We demonstrate these reshape optimizations using the following
+trivialReshape.toy program:
+
+```c++
+def main() {
+  var a<2,1> = [1, 2];
+  var b<2,1> = a;
+  var c<2,1> = b;
+  print(c);
+}
+```
+
+```MLIR(.mlir)
+module {
+  func @main() {
+    %0 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf64>}
+      : () -> tensor<2xf64>
+    %1 = "toy.reshape"(%0) : (tensor<2xf64>) -> tensor<2x1xf64>
+    %2 = "toy.reshape"(%1) : (tensor<2x1xf64>) -> tensor<2x1xf64>
+    %3 = "toy.reshape"(%2) : (tensor<2x1xf64>) -> tensor<2x1xf64>
+    "toy.print"(%3) : (tensor<2x1xf64>) -> ()
+    "toy.return"() : () -> ()
+  }
+}
+```
+
+We can try to run `toyc test/trivialReshape.toy -emit=mlir -opt` and observe our
+pattern in action:
+
+```MLIR(.mlir)
+module {
+  func @main() {
+    %0 = "toy.constant"() {value = dense<[[1.000000e+00], [2.000000e+00]]> \
+      : tensor<2x1xf64>} : () -> tensor<2x1xf64>
+    "toy.print"(%0) : (tensor<2x1xf64>) -> ()
+    "toy.return"() : () -> ()
+  }
+}
+```
+
+As expected, no reshape operations remain after canonicalization.
+ +Further details on the declarative rewrite method can be found at +[Table-driven Declarative Rewrite Rule (DRR)](../../DeclarativeRewrites.md). diff --git a/mlir/test/Examples/Toy/Ch3/codegen.toy b/mlir/test/Examples/Toy/Ch3/codegen.toy index d4949cd013e5..7103e549ca3f 100644 --- a/mlir/test/Examples/Toy/Ch3/codegen.toy +++ b/mlir/test/Examples/Toy/Ch3/codegen.toy @@ -13,20 +13,19 @@ def main() { print(d); } -# CHECK-LABEL: func @multiply_transpose(%arg0: !toy.array, %arg1: !toy.array) -# CHECK-NEXT: attributes {toy.generic = true} { -# CHECK-NEXT: %0 = "toy.transpose"(%arg1) : (!toy.array) -> !toy.array -# CHECK-NEXT: %1 = "toy.mul"(%arg0, %0) : (!toy.array, !toy.array) -> !toy.array -# CHECK-NEXT: "toy.return"(%1) : (!toy.array) -> () -# CHECK-NEXT: } - -# CHECK-LABEL: func @main() { -# CHECK-NEXT: %0 = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> !toy.array<2, 3> -# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<2, 3>) -> !toy.array<2, 3> -# CHECK-NEXT: %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> !toy.array<6> -# CHECK-NEXT: %3 = "toy.reshape"(%2) : (!toy.array<6>) -> !toy.array<2, 3> -# CHECK-NEXT: %4 = "toy.generic_call"(%1, %3) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array -# CHECK-NEXT: %5 = "toy.generic_call"(%3, %1) {callee = "multiply_transpose"} : (!toy.array<2, 3>, !toy.array<2, 3>) -> !toy.array -# CHECK-NEXT: "toy.print"(%5) : (!toy.array) -> () -# CHECK-NEXT: "toy.return"() : () -> () +# CHECK-LABEL: func @multiply_transpose( +# CHECK-SAME: [[VAL_0:%.*]]: tensor<*xf64>, [[VAL_1:%.*]]: tensor<*xf64>) +# CHECK-NEXT: attributes {toy.generic} { +# CHECK-NEXT: [[VAL_2:%.*]] = "toy.transpose"([[VAL_1]]) : (tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_3:%.*]] = "toy.mul"([[VAL_0]], [[VAL_2]]) : 
(tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64> +# CHECK-NEXT: "toy.return"([[VAL_3]]) : (tensor<*xf64>) -> () +# CHECK-LABEL: func @main() { +# CHECK-NEXT: [[VAL_4:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_5:%.*]] = "toy.reshape"([[VAL_4]]) : (tensor<2x3xf64>) -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_6:%.*]] = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64> +# CHECK-NEXT: [[VAL_7:%.*]] = "toy.reshape"([[VAL_6]]) : (tensor<6xf64>) -> tensor<2x3xf64> +# CHECK-NEXT: [[VAL_8:%.*]] = "toy.generic_call"([[VAL_5]], [[VAL_7]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: [[VAL_9:%.*]] = "toy.generic_call"([[VAL_7]], [[VAL_5]]) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64> +# CHECK-NEXT: "toy.print"([[VAL_9]]) : (tensor<*xf64>) -> () +# CHECK-NEXT: "toy.return"() : () -> () diff --git a/mlir/test/Examples/Toy/Ch3/invalid.mlir b/mlir/test/Examples/Toy/Ch3/invalid.mlir index e750b52a6402..558cf9c154d2 100644 --- a/mlir/test/Examples/Toy/Ch3/invalid.mlir +++ b/mlir/test/Examples/Toy/Ch3/invalid.mlir @@ -1,11 +1,9 @@ // RUN: not toyc-ch3 %s -emit=mlir 2>&1 - -// This IR is not "valid": +// The following IR is not "valid": // - toy.print should not return a value. // - toy.print should take an argument. // - There should be a block terminator. -// This all round-trip since this is opaque for MLIR. 
func @main() { - %0 = "toy.print"() : () -> !toy.array<2, 3> + %0 = "toy.print"() : () -> tensor<2x3xf64> } diff --git a/mlir/test/Examples/Toy/Ch3/scalar.toy b/mlir/test/Examples/Toy/Ch3/scalar.toy index 97863c104589..dd7ec935e88b 100644 --- a/mlir/test/Examples/Toy/Ch3/scalar.toy +++ b/mlir/test/Examples/Toy/Ch3/scalar.toy @@ -6,9 +6,9 @@ def main() { } # CHECK-LABEL: func @main() { -# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor<1xf64>} : () -> !toy.array<1> -# CHECK-NEXT: %1 = "toy.reshape"(%0) : (!toy.array<1>) -> !toy.array<2, 2> -# CHECK-NEXT: "toy.print"(%1) : (!toy.array<2, 2>) -> () +# CHECK-NEXT: %0 = "toy.constant"() {value = dense<5.500000e+00> : tensor} : () -> tensor +# CHECK-NEXT: %1 = "toy.reshape"(%0) : (tensor) -> tensor<2x2xf64> +# CHECK-NEXT: "toy.print"(%1) : (tensor<2x2xf64>) -> () # CHECK-NEXT: "toy.return"() : () -> () # CHECK-NEXT: }