Add support for generating operation interfaces from the ODS framework.

Operation interfaces generally require a bit of boilerplate code to connect all of the pieces together. This cl introduces mechanisms in the ODS to allow for generating operation interfaces via the 'OpInterface' class.

Providing a definition of the `OpInterface` class will auto-generate the c++
classes for the interface. An `OpInterface` includes a name, for the c++ class,
along with a list of interface methods. There are two types of methods that can be used with an interface, `InterfaceMethod` and `StaticInterfaceMethod`. They are both comprised of the same core components, with the distinction that `StaticInterfaceMethod` models a static method on the derived operation.

An `InterfaceMethod` is comprised of the following components:
    * ReturnType
      - A string corresponding to the c++ return type of the method.
    * MethodName
      - A string corresponding to the desired name of the method.
    * Arguments
      - A dag of strings that correspond to a c++ type and variable name
        respectively.
    * MethodBody (Optional)
      - An optional explicit implementation of the interface method.

def MyInterface : OpInterface<"MyInterface"> {
  let methods = [
    // A simple non-static method with no inputs.
    InterfaceMethod<"unsigned", "foo">,

    // A new non-static method accepting an input argument.
    InterfaceMethod<"Value *", "bar", (ins "unsigned":$i)>,

    // Query a static property of the derived operation.
    StaticInterfaceMethod<"unsigned", "fooStatic">,

    // Provide the definition of a static interface method.
    // Note: `ConcreteOp` corresponds to the derived operation typename.
    StaticInterfaceMethod<"Operation *", "create",
      (ins "OpBuilder &":$builder, "Location":$loc), [{
        return builder.create<ConcreteOp>(loc);
    }]>,

    // Provide a definition of the non-static method.
    // Note: `op` corresponds to the derived operation variable.
    InterfaceMethod<"unsigned", "getNumInputsAndOutputs", (ins), [{
      return op.getNumInputs() + op.getNumOutputs();
    }]>,
  ];

PiperOrigin-RevId: 264754898
This commit is contained in:
River Riddle
2019-08-21 20:57:23 -07:00
committed by A. Unique TensorFlower
parent 85bc8655f0
commit b9377d7ec6
9 changed files with 437 additions and 154 deletions

View File

@@ -171,3 +171,21 @@ Operation *op = ...;
if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op))
llvm::errs() << "num inputs = " << example.getNumInputs() << "\n";
```
#### Utilizing the ODS Framework
Operation interfaces require a bit of boiler plate to connect all of the pieces
together. The ODS(Operation Definition Specification) framework provides
simplified mechanisms for
[defining interfaces](OpDefinitions.md#operation-interfaces).
As an example, using the ODS framework would allow for defining the example
interface above as:
```tablegen
def ExampleOpInterface : OpInterface<"ExampleOpInterface"> {
let methods = [
InterfaceMethod<"unsigned", "getNumInputs">,
];
}
```

View File

@@ -281,10 +281,81 @@ same operation.
Traits are operation properties that affect syntax or semantics. MLIR C++
models various traits in the `mlir::OpTrait` namespace.
Both operation traits and constraints involving multiple
operands/attributes/results are provided as the second template parameter to the
`Op` class. They should be deriving from the `OpTrait` class. See
[Constraints](#constraints) for more information.
Both operation traits, [interfaces](#operation-interfaces), and constraints
involving multiple operands/attributes/results are provided as the second
template parameter to the `Op` class. They should be deriving from the `OpTrait`
class. See [Constraints](#constraints) for more information.
### Operation interfaces
[Operation interfaces](Interfaces.md#operation-interfaces) are a mechanism by
which to opaquely call methods and access information on an *Op instance,
without knowing the exact operation type. Operation interfaces defined in C++
can be accessed in the ODS framework via the `OpInterfaceTrait` class. Aside
from using pre-existing interfaces in the C++ API, the ODS framework also
provides a simplified mechanism for defining such interfaces; that removes much
of the boilerplate necessary.
Providing a definition of the `OpInterface` class will auto-generate the C++
classes for the interface. An `OpInterface` includes a name, for the C++ class,
along with a list of interface methods.
```tablegen
def MyInterface : OpInterface<"MyInterface"> {
let methods = [...];
}
```
There are two types of methods that can be used with an interface,
`InterfaceMethod` and `StaticInterfaceMethod`. They are both comprised of the
same core components, with the distinction that `StaticInterfaceMethod` models a
static method on the derived operation.
An `InterfaceMethod` is comprised of the following components:
* ReturnType
- A string corresponding to the C++ return type of the method.
* MethodName
- A string corresponding to the desired name of the method.
* Arguments (Optional)
- A dag of strings that correspond to a C++ type and variable name
respectively.
* MethodBody (Optional)
- An optional explicit implementation of the interface method.
- `ConcreteOp` is an implicitly defined typename that can be used to refer
to the type of the derived operation currently being operated on.
- In non-static methods, a variable 'ConcreteOp op' is defined and may be
used to refer to an instance of the derived operation.
Examples:
```tablegen
def MyInterface : OpInterface<"MyInterface"> {
let methods = [
// A simple non-static method with no inputs.
InterfaceMethod<"unsigned", "foo">,
// A new non-static method accepting an input argument.
InterfaceMethod<"Value *", "bar", (ins "unsigned":$i)>,
// Query a static property of the derived operation.
StaticInterfaceMethod<"unsigned", "fooStatic">,
// Provide the definition of a static interface method.
// Note: `ConcreteOp` corresponds to the derived operation typename.
StaticInterfaceMethod<"Operation *", "create",
(ins "OpBuilder &":$builder, "Location":$loc), [{
return builder.create<ConcreteOp>(loc);
}]>,
// Provide a definition of the non-static method.
// Note: `op` corresponds to the derived operation variable.
InterfaceMethod<"unsigned", "getNumInputsAndOutputs", (ins), [{
return op.getNumInputs() + op.getNumOutputs();
}]>,
];
}
```
### Custom builder methods

View File

@@ -5,4 +5,6 @@ add_public_tablegen_target(MLIRLinalgOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LinalgLibraryOps.td)
mlir_tablegen(LinalgLibraryOps.h.inc -gen-op-decls)
mlir_tablegen(LinalgLibraryOps.cpp.inc -gen-op-defs)
mlir_tablegen(LinalgLibraryOpInterfaces.h.inc -gen-op-interface-decls)
mlir_tablegen(LinalgLibraryOpInterfaces.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRLinalgLibraryOpsIncGen)

View File

@@ -70,7 +70,49 @@ def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;
// The linalg 'LinalgLibraryInterface' provides access to the 'LinalgOp'
// interface.
def LinalgLibraryInterface : NativeOpInterface<"LinalgOp">;
def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
let methods = [
/// Query the number of inputs and outputs from the operation.
InterfaceMethod<"unsigned", "getNumInputs">,
InterfaceMethod<"unsigned", "getNumOutputs">,
InterfaceMethod<"unsigned", "getNumInputsAndOutputs">,
InterfaceMethod<"Operation::operand_range", "getInputs">,
InterfaceMethod<"Operation::operand_range", "getOutputs">,
InterfaceMethod<"Operation::operand_range", "getInputsAndOutputs">,
/// Query the number of each type of loop.
InterfaceMethod<"unsigned", "getNumParallelLoops">,
InterfaceMethod<"unsigned", "getNumReductionLoops">,
InterfaceMethod<"unsigned", "getNumWindowLoops">,
InterfaceMethod<"unsigned", "getNumLoops", (ins), [{
return op.getNumParallelLoops() + op.getNumReductionLoops() +
op.getNumWindowLoops();
}]>,
/// Get a specific input/output at the given index.
InterfaceMethod<"Value *", "getInput", (ins "unsigned":$i)>,
InterfaceMethod<"Value *", "getOutput", (ins "unsigned":$i)>,
/// Get the index of the given value, or None if the value is not an input.
InterfaceMethod<"llvm::Optional<unsigned>", "getIndexOfInput",
(ins "Value *":$view)>,
InterfaceMethod<"llvm::Optional<unsigned>", "getIndexOfOutput",
(ins "Value *":$view)>,
/// Get the view type of the input/output at the given index.
InterfaceMethod<"ViewType", "getInputViewType", (ins "unsigned":$i)>,
InterfaceMethod<"ViewType", "getOutputViewType", (ins "unsigned":$i)>,
/// Create an operation with the given location and operands.
StaticInterfaceMethod<"Operation *", "create",
(ins "OpBuilder &":$builder, "Location":$loc,
"ArrayRef<Value *>":$operands,
"ArrayRef<NamedAttribute>":$attributes), [{
return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
attributes);
}]>
];
}
// Base Tablegen class for Linalg ops.
// Linalg ops that correspond to library calls operate on linalg::View as their

View File

@@ -73,147 +73,7 @@ std::string generateLibraryCallName(Operation *op);
/// Only permutation maps are currently supported.
SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
namespace detail {
struct LinalgOpInterfaceTraits {
struct Concept {
virtual ~Concept() = default;
virtual unsigned getNumInputs(Operation *op) = 0;
virtual unsigned getNumOutputs(Operation *op) = 0;
virtual unsigned getNumInputsAndOutputs(Operation *op) = 0;
virtual unsigned getNumParallelLoops(Operation *op) = 0;
virtual unsigned getNumReductionLoops(Operation *op) = 0;
virtual unsigned getNumWindowLoops(Operation *op) = 0;
virtual Value *getInput(Operation *op, unsigned i) = 0;
virtual llvm::Optional<unsigned> getIndexOfInput(Operation *op,
Value *view) = 0;
virtual ViewType getInputViewType(Operation *op, unsigned i) = 0;
virtual Operation::operand_range getInputs(Operation *op) = 0;
virtual Value *getOutput(Operation *op, unsigned i) = 0;
virtual llvm::Optional<unsigned> getIndexOfOutput(Operation *op,
Value *view) = 0;
virtual ViewType getOutputViewType(Operation *op, unsigned i) = 0;
virtual Operation::operand_range getOutputs(Operation *op) = 0;
virtual Operation::operand_range getInputsAndOutputs(Operation *op) = 0;
virtual Operation *create(OpBuilder &builder, Location loc,
ArrayRef<Value *> operands,
ArrayRef<NamedAttribute> attributes) = 0;
};
template <typename ConcreteOp> struct Model : public Concept {
unsigned getNumInputs(Operation *op) override {
return cast<ConcreteOp>(op).getNumInputs();
}
unsigned getNumOutputs(Operation *op) override {
return cast<ConcreteOp>(op).getNumOutputs();
}
unsigned getNumInputsAndOutputs(Operation *op) override {
return cast<ConcreteOp>(op).getNumInputsAndOutputs();
}
unsigned getNumParallelLoops(Operation *op) override {
return cast<ConcreteOp>(op).getNumParallelLoops();
}
unsigned getNumReductionLoops(Operation *op) override {
return cast<ConcreteOp>(op).getNumReductionLoops();
}
unsigned getNumWindowLoops(Operation *op) override {
return cast<ConcreteOp>(op).getNumWindowLoops();
}
Value *getInput(Operation *op, unsigned i) override {
return cast<ConcreteOp>(op).getInput(i);
}
llvm::Optional<unsigned> getIndexOfInput(Operation *op,
Value *view) override {
return cast<ConcreteOp>(op).getIndexOfInput(view);
}
ViewType getInputViewType(Operation *op, unsigned i) override {
return cast<ConcreteOp>(op).getInputViewType(i);
}
Operation::operand_range getInputs(Operation *op) override {
return cast<ConcreteOp>(op).getInputs();
}
Value *getOutput(Operation *op, unsigned i) override {
return cast<ConcreteOp>(op).getOutput(i);
}
llvm::Optional<unsigned> getIndexOfOutput(Operation *op,
Value *view) override {
return cast<ConcreteOp>(op).getIndexOfOutput(view);
}
ViewType getOutputViewType(Operation *op, unsigned i) override {
return cast<ConcreteOp>(op).getOutputViewType(i);
}
Operation::operand_range getOutputs(Operation *op) override {
return cast<ConcreteOp>(op).getOutputs();
}
Operation::operand_range getInputsAndOutputs(Operation *op) override {
return cast<ConcreteOp>(op).getInputsAndOutputs();
}
Operation *create(OpBuilder &builder, Location loc,
ArrayRef<Value *> operands,
ArrayRef<NamedAttribute> attributes) override {
return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
attributes);
}
};
};
} // namespace detail
/// A LinalgOp behaves like a base class for the Linalg operations that are
/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
/// method and dispatches to the appropriate LinalgLibraryOp.
/// This allows writing generic passes, like tiling, for all current and future
/// LinalgOps without requiring templating and dispatch in multiple places.
class LinalgOp : public OpInterface<LinalgOp, detail::LinalgOpInterfaceTraits> {
public:
using OpInterface<LinalgOp, detail::LinalgOpInterfaceTraits>::OpInterface;
unsigned getNumParallelLoops() {
return getImpl()->getNumParallelLoops(getOperation());
}
unsigned getNumReductionLoops() {
return getImpl()->getNumReductionLoops(getOperation());
}
unsigned getNumWindowLoops() {
return getImpl()->getNumWindowLoops(getOperation());
}
unsigned getNumLoops() {
return getNumParallelLoops() + getNumReductionLoops() + getNumWindowLoops();
}
unsigned getNumInputs() { return getImpl()->getNumInputs(getOperation()); }
unsigned getNumOutputs() { return getImpl()->getNumOutputs(getOperation()); }
unsigned getNumInputsAndOutputs() {
return getImpl()->getNumInputsAndOutputs(getOperation());
}
Value *getInput(unsigned i) { return getImpl()->getInput(getOperation(), i); }
llvm::Optional<unsigned> getIndexOfInput(Value *view) {
return getImpl()->getIndexOfInput(getOperation(), view);
}
ViewType getInputViewType(unsigned i) {
return getImpl()->getInputViewType(getOperation(), i);
}
Operation::operand_range getInputs() {
return getImpl()->getInputs(getOperation());
}
Value *getOutput(unsigned i) {
return getImpl()->getOutput(getOperation(), i);
}
llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
return getImpl()->getIndexOfOutput(getOperation(), view);
}
ViewType getOutputViewType(unsigned i) {
return getImpl()->getOutputViewType(getOperation(), i);
}
Operation::operand_range getOutputs() {
return getImpl()->getOutputs(getOperation());
}
Operation::operand_range getInputsAndOutputs() {
return getImpl()->getInputsAndOutputs(getOperation());
}
LinalgOp create(OpBuilder &builder, Location loc, ArrayRef<Value *> operands,
ArrayRef<NamedAttribute> attributes) {
return LinalgOp(getImpl()->create(builder, loc, operands, attributes));
}
};
#include "mlir/Linalg/IR/LinalgLibraryOpInterfaces.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.h.inc"

View File

@@ -1089,22 +1089,52 @@ def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
// OpInterface definitions
//===----------------------------------------------------------------------===//
// NativeOpInterface corresponds to a specific 'OpInterface' class defined in
// Marker used to identify the argument list for an op or interface method.
def ins;
// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
// C++. The purpose to wrap around C++ symbol string with this class is to make
// interfaces specified for ops in TableGen less alien and more integrated.
class NativeOpInterface<string prop> : NativeOpTrait<""> {
// TODO(riverriddle) Remove when operation interfaces have their own trait
// subclass.
let trait = prop # "::Trait";
class OpInterfaceTrait<string name> : NativeOpTrait<""> {
let trait = name # "::Trait";
}
// This class represents a single, optionally static, interface method.
// Note: non-static interface methods have an implicit 'op' parameter
// corresponding to an instance of the derived operation.
class InterfaceMethod<string retTy, string methodName,
dag args = (ins), code methodBody = [{}]> {
/// The name of the interface method.
string name = methodName;
/// The c++ type-name of the return type.
string returnType = retTy;
/// A dag of string that correspond to the arguments of the method.
dag arguments = args;
/// An optional body to the method.
code body = methodBody;
}
// This class represents a single static interface method.
class StaticInterfaceMethod<string retTy, string methodName,
dag args = (ins), code methodBody = [{}]>
: InterfaceMethod<retTy, methodName, args, methodBody>;
// OpInterface represents an interface regarding an op.
class OpInterface<string name> : OpInterfaceTrait<name> {
// The name given to the c++ interface class.
string cppClassName = name;
/// The list of methods defined by this interface.
list<InterfaceMethod> methods = [];
}
//===----------------------------------------------------------------------===//
// Op definitions
//===----------------------------------------------------------------------===//
// Marker used to identify the argument list for an op.
def ins;
// Marker used to identify the result list for an op.
def outs;

View File

@@ -829,6 +829,8 @@ llvm::raw_ostream &mlir::linalg::operator<<(llvm::raw_ostream &os,
namespace mlir {
namespace linalg {
#include "mlir/Linalg/IR/LinalgLibraryOpInterfaces.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

View File

@@ -9,6 +9,7 @@ add_tablegen(mlir-tblgen MLIR
mlir-tblgen.cpp
OpDefinitionsGen.cpp
OpDocGen.cpp
OpInterfacesGen.cpp
ReferenceImplGen.cpp
RewriterGen.cpp
SPIRVUtilsGen.cpp

View File

@@ -0,0 +1,257 @@
//===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// OpInterfacesGen generates definitions for operation interfaces.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/STLExtras.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
namespace {
// This struct represents a single method argument.
struct MethodArgument {
StringRef type, name;
};
// Wrapper class around a single interface method.
class OpInterfaceMethod {
public:
explicit OpInterfaceMethod(const llvm::Record *def) : def(def) {
llvm::DagInit *args = def->getValueAsDag("arguments");
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
arguments.push_back(
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
args->getArgNameStr(i)});
}
}
// Return the return type of this method.
StringRef getReturnType() const {
return def->getValueAsString("returnType");
}
// Return the name of this method.
StringRef getName() const { return def->getValueAsString("name"); }
// Return if this method is static.
bool isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); }
// Return the body for this method if it has one.
llvm::Optional<StringRef> getBody() const {
auto value = def->getValueAsString("body");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
// Arguments.
ArrayRef<MethodArgument> getArguments() const { return arguments; }
bool arg_empty() const { return arguments.empty(); }
protected:
// The TableGen definition of this method.
const llvm::Record *def;
// The arguments of this method.
SmallVector<MethodArgument, 2> arguments;
};
// Wrapper class with helper methods for accessing OpInterfaces defined in
// TableGen.
class OpInterface {
public:
explicit OpInterface(const llvm::Record *def) : def(def) {
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
for (llvm::Init *init : listInit->getValues())
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
}
// Return the name of this interface.
StringRef getName() const { return def->getValueAsString("cppClassName"); }
// Return the methods of this interface.
ArrayRef<OpInterfaceMethod> getMethods() const { return methods; }
protected:
// The TableGen definition of this interface.
const llvm::Record *def;
// The methods of this interface.
SmallVector<OpInterfaceMethod, 8> methods;
};
} // end anonymous namespace
// Emit the method name and argument list for the given method. If
// 'addOperationArg' is true, then an Operation* argument is added to the
// beginning of the argument list.
static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
raw_ostream &os, bool addOperationArg) {
os << method.getName() << '(';
if (addOperationArg)
os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) {
os << arg.type << " " << arg.name;
});
os << ')';
}
static void emitInterfaceDef(const Record &interfaceDef, raw_ostream &os) {
OpInterface interface(&interfaceDef);
StringRef interfaceName = interface.getName();
// Insert the method definitions.
auto *listInit = dyn_cast<ListInit>(interfaceDef.getValueInit("methods"));
for (Init *init : listInit->getValues()) {
OpInterfaceMethod method(cast<DefInit>(init)->getDef());
os << method.getReturnType() << " " << interfaceName << "::";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
// Forward to the method on the concrete operation type.
os << " {\n return getImpl()->" << method.getName() << '(';
if (!method.isStatic())
os << "getOperation()" << (method.arg_empty() ? "" : ", ");
interleaveComma(method.getArguments(), os,
[&](const MethodArgument &arg) { os << arg.name; });
os << ");\n }\n";
}
}
static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("Operation Interface Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface");
for (const auto *def : defs)
emitInterfaceDef(*def, os);
return false;
}
static void emitConceptDecl(const Record &interfaceDef, raw_ostream &os) {
os << " class Concept {\n"
<< " public:\n"
<< " virtual ~Concept() = default;\n";
// Insert each of the virtual methods.
auto *listInit = dyn_cast<ListInit>(interfaceDef.getValueInit("methods"));
for (Init *init : listInit->getValues()) {
OpInterfaceMethod method(cast<DefInit>(init)->getDef());
// In the concept, all methods are pure virtual.
os << " virtual " << method.getReturnType() << " ";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic());
os << " = 0;\n";
}
os << " };\n";
}
static void emitModelDecl(const Record &interfaceDef, raw_ostream &os) {
os << " template<typename ConcreteOp>\n";
os << " class Model : public Concept {\npublic:\n";
// Insert each of the virtual method overrides.
auto *listInit = dyn_cast<ListInit>(interfaceDef.getValueInit("methods"));
for (Init *init : listInit->getValues()) {
OpInterfaceMethod method(cast<DefInit>(init)->getDef());
os << " " << method.getReturnType() << " ";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic());
os << " final {\n";
// Provide a definition of the concrete op if this is non static.
if (!method.isStatic()) {
os << " auto op = llvm::cast<ConcreteOp>(tablegen_opaque_op);\n"
<< " (void)op;\n";
}
// Check for a provided body to the function.
if (auto body = method.getBody()) {
os << body << "\n }\n";
continue;
}
// Forward to the method on the concrete operation type.
os << " return " << (method.isStatic() ? "ConcreteOp::" : "op.");
// Add the arguments to the call.
os << method.getName() << '(';
interleaveComma(method.getArguments(), os,
[&](const MethodArgument &arg) { os << arg.name; });
os << ");\n }\n";
}
os << " };\n";
}
static void emitInterfaceDecl(const Record &interfaceDef, raw_ostream &os) {
OpInterface interface(&interfaceDef);
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
<< "struct " << interfaceTraitsName << " {\n";
emitConceptDecl(interfaceDef, os);
emitModelDecl(interfaceDef, os);
os << "};\n} // end namespace detail\n";
// Emit the main interface class declaration.
os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
"public:\n"
" using OpInterface<{1}, detail::{2}>::OpInterface;\n",
interfaceName, interfaceName, interfaceTraitsName);
// Insert the method declarations.
for (auto &method : interface.getMethods()) {
os << " " << method.getReturnType() << " ";
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
os << ";\n";
}
os << "};\n";
}
static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("Operation Interface Declarations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface");
for (const auto *def : defs)
emitInterfaceDecl(*def, os);
return false;
}
// Registers the operation interface generator to mlir-tblgen.
static mlir::GenRegistration
genInterfaceDecls("gen-op-interface-decls",
"Generate op interface declarations",
[](const RecordKeeper &records, raw_ostream &os) {
return emitInterfaceDecls(records, os);
});
// Registers the operation interface generator to mlir-tblgen.
static mlir::GenRegistration
genInterfaceDefs("gen-op-interface-defs",
"Generate op interface definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitInterfaceDefs(records, os);
});