mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 01:07:04 +08:00
Add support for generating operation interfaces from the ODS framework.
Operation interfaces generally require a bit of boilerplate code to connect all of the pieces together. This cl introduces mechanisms in the ODS to allow for generating operation interfaces via the 'OpInterface' class.
Providing a definition of the `OpInterface` class will auto-generate the c++
classes for the interface. An `OpInterface` includes a name, for the c++ class,
along with a list of interface methods. There are two types of methods that can be used with an interface, `InterfaceMethod` and `StaticInterfaceMethod`. They are both comprised of the same core components, with the distinction that `StaticInterfaceMethod` models a static method on the derived operation.
An `InterfaceMethod` is comprised of the following components:
* ReturnType
- A string corresponding to the c++ return type of the method.
* MethodName
- A string corresponding to the desired name of the method.
* Arguments
- A dag of strings that correspond to a c++ type and variable name
respectively.
* MethodBody (Optional)
- An optional explicit implementation of the interface method.
def MyInterface : OpInterface<"MyInterface"> {
let methods = [
// A simple non-static method with no inputs.
InterfaceMethod<"unsigned", "foo">,
// A new non-static method accepting an input argument.
InterfaceMethod<"Value *", "bar", (ins "unsigned":$i)>,
// Query a static property of the derived operation.
StaticInterfaceMethod<"unsigned", "fooStatic">,
// Provide the definition of a static interface method.
// Note: `ConcreteOp` corresponds to the derived operation typename.
StaticInterfaceMethod<"Operation *", "create",
(ins "OpBuilder &":$builder, "Location":$loc), [{
return builder.create<ConcreteOp>(loc);
}]>,
// Provide a definition of the non-static method.
// Note: `op` corresponds to the derived operation variable.
InterfaceMethod<"unsigned", "getNumInputsAndOutputs", (ins), [{
return op.getNumInputs() + op.getNumOutputs();
}]>,
];
PiperOrigin-RevId: 264754898
This commit is contained in:
committed by
A. Unique TensorFlower
parent
85bc8655f0
commit
b9377d7ec6
@@ -171,3 +171,21 @@ Operation *op = ...;
|
||||
if (ExampleOpInterface example = dyn_cast<ExampleOpInterface>(op))
|
||||
llvm::errs() << "num inputs = " << example.getNumInputs() << "\n";
|
||||
```
|
||||
|
||||
#### Utilizing the ODS Framework
|
||||
|
||||
Operation interfaces require a bit of boiler plate to connect all of the pieces
|
||||
together. The ODS(Operation Definition Specification) framework provides
|
||||
simplified mechanisms for
|
||||
[defining interfaces](OpDefinitions.md#operation-interfaces).
|
||||
|
||||
As an example, using the ODS framework would allow for defining the example
|
||||
interface above as:
|
||||
|
||||
```tablegen
|
||||
def ExampleOpInterface : OpInterface<"ExampleOpInterface"> {
|
||||
let methods = [
|
||||
InterfaceMethod<"unsigned", "getNumInputs">,
|
||||
];
|
||||
}
|
||||
```
|
||||
|
||||
@@ -281,10 +281,81 @@ same operation.
|
||||
Traits are operation properties that affect syntax or semantics. MLIR C++
|
||||
models various traits in the `mlir::OpTrait` namespace.
|
||||
|
||||
Both operation traits and constraints involving multiple
|
||||
operands/attributes/results are provided as the second template parameter to the
|
||||
`Op` class. They should be deriving from the `OpTrait` class. See
|
||||
[Constraints](#constraints) for more information.
|
||||
Both operation traits, [interfaces](#operation-interfaces), and constraints
|
||||
involving multiple operands/attributes/results are provided as the second
|
||||
template parameter to the `Op` class. They should be deriving from the `OpTrait`
|
||||
class. See [Constraints](#constraints) for more information.
|
||||
|
||||
### Operation interfaces
|
||||
|
||||
[Operation interfaces](Interfaces.md#operation-interfaces) are a mechanism by
|
||||
which to opaquely call methods and access information on an *Op instance,
|
||||
without knowing the exact operation type. Operation interfaces defined in C++
|
||||
can be accessed in the ODS framework via the `OpInterfaceTrait` class. Aside
|
||||
from using pre-existing interfaces in the C++ API, the ODS framework also
|
||||
provides a simplified mechanism for defining such interfaces; that removes much
|
||||
of the boilerplate necessary.
|
||||
|
||||
Providing a definition of the `OpInterface` class will auto-generate the C++
|
||||
classes for the interface. An `OpInterface` includes a name, for the C++ class,
|
||||
along with a list of interface methods.
|
||||
|
||||
```tablegen
|
||||
def MyInterface : OpInterface<"MyInterface"> {
|
||||
let methods = [...];
|
||||
}
|
||||
```
|
||||
|
||||
There are two types of methods that can be used with an interface,
|
||||
`InterfaceMethod` and `StaticInterfaceMethod`. They are both comprised of the
|
||||
same core components, with the distinction that `StaticInterfaceMethod` models a
|
||||
static method on the derived operation.
|
||||
|
||||
An `InterfaceMethod` is comprised of the following components:
|
||||
|
||||
* ReturnType
|
||||
- A string corresponding to the C++ return type of the method.
|
||||
* MethodName
|
||||
- A string corresponding to the desired name of the method.
|
||||
* Arguments (Optional)
|
||||
- A dag of strings that correspond to a C++ type and variable name
|
||||
respectively.
|
||||
* MethodBody (Optional)
|
||||
- An optional explicit implementation of the interface method.
|
||||
- `ConcreteOp` is an implicitly defined typename that can be used to refer
|
||||
to the type of the derived operation currently being operated on.
|
||||
- In non-static methods, a variable 'ConcreteOp op' is defined and may be
|
||||
used to refer to an instance of the derived operation.
|
||||
|
||||
Examples:
|
||||
|
||||
```tablegen
|
||||
def MyInterface : OpInterface<"MyInterface"> {
|
||||
let methods = [
|
||||
// A simple non-static method with no inputs.
|
||||
InterfaceMethod<"unsigned", "foo">,
|
||||
|
||||
// A new non-static method accepting an input argument.
|
||||
InterfaceMethod<"Value *", "bar", (ins "unsigned":$i)>,
|
||||
|
||||
// Query a static property of the derived operation.
|
||||
StaticInterfaceMethod<"unsigned", "fooStatic">,
|
||||
|
||||
// Provide the definition of a static interface method.
|
||||
// Note: `ConcreteOp` corresponds to the derived operation typename.
|
||||
StaticInterfaceMethod<"Operation *", "create",
|
||||
(ins "OpBuilder &":$builder, "Location":$loc), [{
|
||||
return builder.create<ConcreteOp>(loc);
|
||||
}]>,
|
||||
|
||||
// Provide a definition of the non-static method.
|
||||
// Note: `op` corresponds to the derived operation variable.
|
||||
InterfaceMethod<"unsigned", "getNumInputsAndOutputs", (ins), [{
|
||||
return op.getNumInputs() + op.getNumOutputs();
|
||||
}]>,
|
||||
];
|
||||
}
|
||||
```
|
||||
|
||||
### Custom builder methods
|
||||
|
||||
|
||||
@@ -5,4 +5,6 @@ add_public_tablegen_target(MLIRLinalgOpsIncGen)
|
||||
set(LLVM_TARGET_DEFINITIONS LinalgLibraryOps.td)
|
||||
mlir_tablegen(LinalgLibraryOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(LinalgLibraryOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(LinalgLibraryOpInterfaces.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(LinalgLibraryOpInterfaces.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(MLIRLinalgLibraryOpsIncGen)
|
||||
|
||||
@@ -70,7 +70,49 @@ def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;
|
||||
|
||||
// The linalg 'LinalgLibraryInterface' provides access to the 'LinalgOp'
|
||||
// interface.
|
||||
def LinalgLibraryInterface : NativeOpInterface<"LinalgOp">;
|
||||
def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
|
||||
let methods = [
|
||||
/// Query the number of inputs and outputs from the operation.
|
||||
InterfaceMethod<"unsigned", "getNumInputs">,
|
||||
InterfaceMethod<"unsigned", "getNumOutputs">,
|
||||
InterfaceMethod<"unsigned", "getNumInputsAndOutputs">,
|
||||
InterfaceMethod<"Operation::operand_range", "getInputs">,
|
||||
InterfaceMethod<"Operation::operand_range", "getOutputs">,
|
||||
InterfaceMethod<"Operation::operand_range", "getInputsAndOutputs">,
|
||||
|
||||
/// Query the number of each type of loop.
|
||||
InterfaceMethod<"unsigned", "getNumParallelLoops">,
|
||||
InterfaceMethod<"unsigned", "getNumReductionLoops">,
|
||||
InterfaceMethod<"unsigned", "getNumWindowLoops">,
|
||||
InterfaceMethod<"unsigned", "getNumLoops", (ins), [{
|
||||
return op.getNumParallelLoops() + op.getNumReductionLoops() +
|
||||
op.getNumWindowLoops();
|
||||
}]>,
|
||||
|
||||
/// Get a specific input/output at the given index.
|
||||
InterfaceMethod<"Value *", "getInput", (ins "unsigned":$i)>,
|
||||
InterfaceMethod<"Value *", "getOutput", (ins "unsigned":$i)>,
|
||||
|
||||
/// Get the index of the given value, or None if the value is not an input.
|
||||
InterfaceMethod<"llvm::Optional<unsigned>", "getIndexOfInput",
|
||||
(ins "Value *":$view)>,
|
||||
InterfaceMethod<"llvm::Optional<unsigned>", "getIndexOfOutput",
|
||||
(ins "Value *":$view)>,
|
||||
|
||||
/// Get the view type of the input/output at the given index.
|
||||
InterfaceMethod<"ViewType", "getInputViewType", (ins "unsigned":$i)>,
|
||||
InterfaceMethod<"ViewType", "getOutputViewType", (ins "unsigned":$i)>,
|
||||
|
||||
/// Create an operation with the given location and operands.
|
||||
StaticInterfaceMethod<"Operation *", "create",
|
||||
(ins "OpBuilder &":$builder, "Location":$loc,
|
||||
"ArrayRef<Value *>":$operands,
|
||||
"ArrayRef<NamedAttribute>":$attributes), [{
|
||||
return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
|
||||
attributes);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
// Base Tablegen class for Linalg ops.
|
||||
// Linalg ops that correspond to library calls operate on linalg::View as their
|
||||
|
||||
@@ -73,147 +73,7 @@ std::string generateLibraryCallName(Operation *op);
|
||||
/// Only permutation maps are currently supported.
|
||||
SmallVector<AffineMap, 4> loopToOperandRangesMaps(Operation *op);
|
||||
|
||||
namespace detail {
|
||||
struct LinalgOpInterfaceTraits {
|
||||
struct Concept {
|
||||
virtual ~Concept() = default;
|
||||
virtual unsigned getNumInputs(Operation *op) = 0;
|
||||
virtual unsigned getNumOutputs(Operation *op) = 0;
|
||||
virtual unsigned getNumInputsAndOutputs(Operation *op) = 0;
|
||||
virtual unsigned getNumParallelLoops(Operation *op) = 0;
|
||||
virtual unsigned getNumReductionLoops(Operation *op) = 0;
|
||||
virtual unsigned getNumWindowLoops(Operation *op) = 0;
|
||||
virtual Value *getInput(Operation *op, unsigned i) = 0;
|
||||
virtual llvm::Optional<unsigned> getIndexOfInput(Operation *op,
|
||||
Value *view) = 0;
|
||||
virtual ViewType getInputViewType(Operation *op, unsigned i) = 0;
|
||||
virtual Operation::operand_range getInputs(Operation *op) = 0;
|
||||
virtual Value *getOutput(Operation *op, unsigned i) = 0;
|
||||
virtual llvm::Optional<unsigned> getIndexOfOutput(Operation *op,
|
||||
Value *view) = 0;
|
||||
virtual ViewType getOutputViewType(Operation *op, unsigned i) = 0;
|
||||
virtual Operation::operand_range getOutputs(Operation *op) = 0;
|
||||
virtual Operation::operand_range getInputsAndOutputs(Operation *op) = 0;
|
||||
virtual Operation *create(OpBuilder &builder, Location loc,
|
||||
ArrayRef<Value *> operands,
|
||||
ArrayRef<NamedAttribute> attributes) = 0;
|
||||
};
|
||||
|
||||
template <typename ConcreteOp> struct Model : public Concept {
|
||||
unsigned getNumInputs(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumInputs();
|
||||
}
|
||||
unsigned getNumOutputs(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumOutputs();
|
||||
}
|
||||
unsigned getNumInputsAndOutputs(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumInputsAndOutputs();
|
||||
}
|
||||
unsigned getNumParallelLoops(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumParallelLoops();
|
||||
}
|
||||
unsigned getNumReductionLoops(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumReductionLoops();
|
||||
}
|
||||
unsigned getNumWindowLoops(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getNumWindowLoops();
|
||||
}
|
||||
Value *getInput(Operation *op, unsigned i) override {
|
||||
return cast<ConcreteOp>(op).getInput(i);
|
||||
}
|
||||
llvm::Optional<unsigned> getIndexOfInput(Operation *op,
|
||||
Value *view) override {
|
||||
return cast<ConcreteOp>(op).getIndexOfInput(view);
|
||||
}
|
||||
ViewType getInputViewType(Operation *op, unsigned i) override {
|
||||
return cast<ConcreteOp>(op).getInputViewType(i);
|
||||
}
|
||||
Operation::operand_range getInputs(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getInputs();
|
||||
}
|
||||
Value *getOutput(Operation *op, unsigned i) override {
|
||||
return cast<ConcreteOp>(op).getOutput(i);
|
||||
}
|
||||
llvm::Optional<unsigned> getIndexOfOutput(Operation *op,
|
||||
Value *view) override {
|
||||
return cast<ConcreteOp>(op).getIndexOfOutput(view);
|
||||
}
|
||||
ViewType getOutputViewType(Operation *op, unsigned i) override {
|
||||
return cast<ConcreteOp>(op).getOutputViewType(i);
|
||||
}
|
||||
Operation::operand_range getOutputs(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getOutputs();
|
||||
}
|
||||
Operation::operand_range getInputsAndOutputs(Operation *op) override {
|
||||
return cast<ConcreteOp>(op).getInputsAndOutputs();
|
||||
}
|
||||
Operation *create(OpBuilder &builder, Location loc,
|
||||
ArrayRef<Value *> operands,
|
||||
ArrayRef<NamedAttribute> attributes) override {
|
||||
return builder.create<ConcreteOp>(loc, ArrayRef<Type>{}, operands,
|
||||
attributes);
|
||||
}
|
||||
};
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// A LinalgOp behaves like a base class for the Linalg operations that are
|
||||
/// defined in LinalgLibraryOps.td. The implementation does not use inheritance
|
||||
/// directly. Instead, a LinalgOp directly derives from Op, hides the `classof`
|
||||
/// method and dispatches to the appropriate LinalgLibraryOp.
|
||||
/// This allows writing generic passes, like tiling, for all current and future
|
||||
/// LinalgOps without requiring templating and dispatch in multiple places.
|
||||
class LinalgOp : public OpInterface<LinalgOp, detail::LinalgOpInterfaceTraits> {
|
||||
public:
|
||||
using OpInterface<LinalgOp, detail::LinalgOpInterfaceTraits>::OpInterface;
|
||||
|
||||
unsigned getNumParallelLoops() {
|
||||
return getImpl()->getNumParallelLoops(getOperation());
|
||||
}
|
||||
unsigned getNumReductionLoops() {
|
||||
return getImpl()->getNumReductionLoops(getOperation());
|
||||
}
|
||||
unsigned getNumWindowLoops() {
|
||||
return getImpl()->getNumWindowLoops(getOperation());
|
||||
}
|
||||
unsigned getNumLoops() {
|
||||
return getNumParallelLoops() + getNumReductionLoops() + getNumWindowLoops();
|
||||
}
|
||||
unsigned getNumInputs() { return getImpl()->getNumInputs(getOperation()); }
|
||||
unsigned getNumOutputs() { return getImpl()->getNumOutputs(getOperation()); }
|
||||
unsigned getNumInputsAndOutputs() {
|
||||
return getImpl()->getNumInputsAndOutputs(getOperation());
|
||||
}
|
||||
Value *getInput(unsigned i) { return getImpl()->getInput(getOperation(), i); }
|
||||
llvm::Optional<unsigned> getIndexOfInput(Value *view) {
|
||||
return getImpl()->getIndexOfInput(getOperation(), view);
|
||||
}
|
||||
ViewType getInputViewType(unsigned i) {
|
||||
return getImpl()->getInputViewType(getOperation(), i);
|
||||
}
|
||||
Operation::operand_range getInputs() {
|
||||
return getImpl()->getInputs(getOperation());
|
||||
}
|
||||
Value *getOutput(unsigned i) {
|
||||
return getImpl()->getOutput(getOperation(), i);
|
||||
}
|
||||
llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
|
||||
return getImpl()->getIndexOfOutput(getOperation(), view);
|
||||
}
|
||||
ViewType getOutputViewType(unsigned i) {
|
||||
return getImpl()->getOutputViewType(getOperation(), i);
|
||||
}
|
||||
Operation::operand_range getOutputs() {
|
||||
return getImpl()->getOutputs(getOperation());
|
||||
}
|
||||
Operation::operand_range getInputsAndOutputs() {
|
||||
return getImpl()->getInputsAndOutputs(getOperation());
|
||||
}
|
||||
LinalgOp create(OpBuilder &builder, Location loc, ArrayRef<Value *> operands,
|
||||
ArrayRef<NamedAttribute> attributes) {
|
||||
return LinalgOp(getImpl()->create(builder, loc, operands, attributes));
|
||||
}
|
||||
};
|
||||
#include "mlir/Linalg/IR/LinalgLibraryOpInterfaces.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h.inc"
|
||||
|
||||
@@ -1089,22 +1089,52 @@ def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
|
||||
// OpInterface definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// NativeOpInterface corresponds to a specific 'OpInterface' class defined in
|
||||
// Marker used to identify the argument list for an op or interface method.
|
||||
def ins;
|
||||
|
||||
// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
|
||||
// C++. The purpose to wrap around C++ symbol string with this class is to make
|
||||
// interfaces specified for ops in TableGen less alien and more integrated.
|
||||
class NativeOpInterface<string prop> : NativeOpTrait<""> {
|
||||
// TODO(riverriddle) Remove when operation interfaces have their own trait
|
||||
// subclass.
|
||||
let trait = prop # "::Trait";
|
||||
class OpInterfaceTrait<string name> : NativeOpTrait<""> {
|
||||
let trait = name # "::Trait";
|
||||
}
|
||||
|
||||
// This class represents a single, optionally static, interface method.
|
||||
// Note: non-static interface methods have an implicit 'op' parameter
|
||||
// corresponding to an instance of the derived operation.
|
||||
class InterfaceMethod<string retTy, string methodName,
|
||||
dag args = (ins), code methodBody = [{}]> {
|
||||
/// The name of the interface method.
|
||||
string name = methodName;
|
||||
|
||||
/// The c++ type-name of the return type.
|
||||
string returnType = retTy;
|
||||
|
||||
/// A dag of string that correspond to the arguments of the method.
|
||||
dag arguments = args;
|
||||
|
||||
/// An optional body to the method.
|
||||
code body = methodBody;
|
||||
}
|
||||
|
||||
// This class represents a single static interface method.
|
||||
class StaticInterfaceMethod<string retTy, string methodName,
|
||||
dag args = (ins), code methodBody = [{}]>
|
||||
: InterfaceMethod<retTy, methodName, args, methodBody>;
|
||||
|
||||
// OpInterface represents an interface regarding an op.
|
||||
class OpInterface<string name> : OpInterfaceTrait<name> {
|
||||
// The name given to the c++ interface class.
|
||||
string cppClassName = name;
|
||||
|
||||
/// The list of methods defined by this interface.
|
||||
list<InterfaceMethod> methods = [];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Marker used to identify the argument list for an op.
|
||||
def ins;
|
||||
|
||||
// Marker used to identify the result list for an op.
|
||||
def outs;
|
||||
|
||||
|
||||
@@ -829,6 +829,8 @@ llvm::raw_ostream &mlir::linalg::operator<<(llvm::raw_ostream &os,
|
||||
namespace mlir {
|
||||
namespace linalg {
|
||||
|
||||
#include "mlir/Linalg/IR/LinalgLibraryOpInterfaces.cpp.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ add_tablegen(mlir-tblgen MLIR
|
||||
mlir-tblgen.cpp
|
||||
OpDefinitionsGen.cpp
|
||||
OpDocGen.cpp
|
||||
OpInterfacesGen.cpp
|
||||
ReferenceImplGen.cpp
|
||||
RewriterGen.cpp
|
||||
SPIRVUtilsGen.cpp
|
||||
|
||||
257
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
Normal file
257
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
Normal file
@@ -0,0 +1,257 @@
|
||||
//===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// OpInterfacesGen generates definitions for operation interfaces.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
// This struct represents a single method argument.
|
||||
struct MethodArgument {
|
||||
StringRef type, name;
|
||||
};
|
||||
|
||||
// Wrapper class around a single interface method.
|
||||
class OpInterfaceMethod {
|
||||
public:
|
||||
explicit OpInterfaceMethod(const llvm::Record *def) : def(def) {
|
||||
llvm::DagInit *args = def->getValueAsDag("arguments");
|
||||
for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
|
||||
arguments.push_back(
|
||||
{llvm::cast<llvm::StringInit>(args->getArg(i))->getValue(),
|
||||
args->getArgNameStr(i)});
|
||||
}
|
||||
}
|
||||
|
||||
// Return the return type of this method.
|
||||
StringRef getReturnType() const {
|
||||
return def->getValueAsString("returnType");
|
||||
}
|
||||
|
||||
// Return the name of this method.
|
||||
StringRef getName() const { return def->getValueAsString("name"); }
|
||||
|
||||
// Return if this method is static.
|
||||
bool isStatic() const { return def->isSubClassOf("StaticInterfaceMethod"); }
|
||||
|
||||
// Return the body for this method if it has one.
|
||||
llvm::Optional<StringRef> getBody() const {
|
||||
auto value = def->getValueAsString("body");
|
||||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
// Arguments.
|
||||
ArrayRef<MethodArgument> getArguments() const { return arguments; }
|
||||
bool arg_empty() const { return arguments.empty(); }
|
||||
|
||||
protected:
|
||||
// The TableGen definition of this method.
|
||||
const llvm::Record *def;
|
||||
|
||||
// The arguments of this method.
|
||||
SmallVector<MethodArgument, 2> arguments;
|
||||
};
|
||||
|
||||
// Wrapper class with helper methods for accessing OpInterfaces defined in
|
||||
// TableGen.
|
||||
class OpInterface {
|
||||
public:
|
||||
explicit OpInterface(const llvm::Record *def) : def(def) {
|
||||
auto *listInit = dyn_cast<llvm::ListInit>(def->getValueInit("methods"));
|
||||
for (llvm::Init *init : listInit->getValues())
|
||||
methods.emplace_back(cast<llvm::DefInit>(init)->getDef());
|
||||
}
|
||||
|
||||
// Return the name of this interface.
|
||||
StringRef getName() const { return def->getValueAsString("cppClassName"); }
|
||||
|
||||
// Return the methods of this interface.
|
||||
ArrayRef<OpInterfaceMethod> getMethods() const { return methods; }
|
||||
|
||||
protected:
|
||||
// The TableGen definition of this interface.
|
||||
const llvm::Record *def;
|
||||
|
||||
// The methods of this interface.
|
||||
SmallVector<OpInterfaceMethod, 8> methods;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
// Emit the method name and argument list for the given method. If
|
||||
// 'addOperationArg' is true, then an Operation* argument is added to the
|
||||
// beginning of the argument list.
|
||||
static void emitMethodNameAndArgs(const OpInterfaceMethod &method,
|
||||
raw_ostream &os, bool addOperationArg) {
|
||||
os << method.getName() << '(';
|
||||
if (addOperationArg)
|
||||
os << "Operation *tablegen_opaque_op" << (method.arg_empty() ? "" : ", ");
|
||||
interleaveComma(method.getArguments(), os, [&](const MethodArgument &arg) {
|
||||
os << arg.type << " " << arg.name;
|
||||
});
|
||||
os << ')';
|
||||
}
|
||||
|
||||
static void emitInterfaceDef(const Record &interfaceDef, raw_ostream &os) {
|
||||
OpInterface interface(&interfaceDef);
|
||||
StringRef interfaceName = interface.getName();
|
||||
|
||||
// Insert the method definitions.
|
||||
auto *listInit = dyn_cast<ListInit>(interfaceDef.getValueInit("methods"));
|
||||
for (Init *init : listInit->getValues()) {
|
||||
OpInterfaceMethod method(cast<DefInit>(init)->getDef());
|
||||
os << method.getReturnType() << " " << interfaceName << "::";
|
||||
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
|
||||
|
||||
// Forward to the method on the concrete operation type.
|
||||
os << " {\n return getImpl()->" << method.getName() << '(';
|
||||
if (!method.isStatic())
|
||||
os << "getOperation()" << (method.arg_empty() ? "" : ", ");
|
||||
interleaveComma(method.getArguments(), os,
|
||||
[&](const MethodArgument &arg) { os << arg.name; });
|
||||
os << ");\n }\n";
|
||||
}
|
||||
}
|
||||
|
||||
static bool emitInterfaceDefs(const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("Operation Interface Definitions", os);
|
||||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface");
|
||||
for (const auto *def : defs)
|
||||
emitInterfaceDef(*def, os);
|
||||
return false;
|
||||
}
|
||||
|
||||
static void emitConceptDecl(const Record &interfaceDef, raw_ostream &os) {
|
||||
os << " class Concept {\n"
|
||||
<< " public:\n"
|
||||
<< " virtual ~Concept() = default;\n";
|
||||
|
||||
// Insert each of the virtual methods.
|
||||
auto *listInit = dyn_cast<ListInit>(interfaceDef.getValueInit("methods"));
|
||||
for (Init *init : listInit->getValues()) {
|
||||
OpInterfaceMethod method(cast<DefInit>(init)->getDef());
|
||||
|
||||
// In the concept, all methods are pure virtual.
|
||||
os << " virtual " << method.getReturnType() << " ";
|
||||
emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic());
|
||||
os << " = 0;\n";
|
||||
}
|
||||
os << " };\n";
|
||||
}
|
||||
|
||||
static void emitModelDecl(const Record &interfaceDef, raw_ostream &os) {
|
||||
os << " template<typename ConcreteOp>\n";
|
||||
os << " class Model : public Concept {\npublic:\n";
|
||||
|
||||
// Insert each of the virtual method overrides.
|
||||
auto *listInit = dyn_cast<ListInit>(interfaceDef.getValueInit("methods"));
|
||||
for (Init *init : listInit->getValues()) {
|
||||
OpInterfaceMethod method(cast<DefInit>(init)->getDef());
|
||||
os << " " << method.getReturnType() << " ";
|
||||
emitMethodNameAndArgs(method, os, /*addOperationArg=*/!method.isStatic());
|
||||
os << " final {\n";
|
||||
|
||||
// Provide a definition of the concrete op if this is non static.
|
||||
if (!method.isStatic()) {
|
||||
os << " auto op = llvm::cast<ConcreteOp>(tablegen_opaque_op);\n"
|
||||
<< " (void)op;\n";
|
||||
}
|
||||
|
||||
// Check for a provided body to the function.
|
||||
if (auto body = method.getBody()) {
|
||||
os << body << "\n }\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Forward to the method on the concrete operation type.
|
||||
os << " return " << (method.isStatic() ? "ConcreteOp::" : "op.");
|
||||
|
||||
// Add the arguments to the call.
|
||||
os << method.getName() << '(';
|
||||
interleaveComma(method.getArguments(), os,
|
||||
[&](const MethodArgument &arg) { os << arg.name; });
|
||||
os << ");\n }\n";
|
||||
}
|
||||
os << " };\n";
|
||||
}
|
||||
|
||||
static void emitInterfaceDecl(const Record &interfaceDef, raw_ostream &os) {
|
||||
OpInterface interface(&interfaceDef);
|
||||
StringRef interfaceName = interface.getName();
|
||||
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
|
||||
|
||||
// Emit the traits struct containing the concept and model declarations.
|
||||
os << "namespace detail {\n"
|
||||
<< "struct " << interfaceTraitsName << " {\n";
|
||||
emitConceptDecl(interfaceDef, os);
|
||||
emitModelDecl(interfaceDef, os);
|
||||
os << "};\n} // end namespace detail\n";
|
||||
|
||||
// Emit the main interface class declaration.
|
||||
os << llvm::formatv("class {0} : public OpInterface<{1}, detail::{2}> {\n"
|
||||
"public:\n"
|
||||
" using OpInterface<{1}, detail::{2}>::OpInterface;\n",
|
||||
interfaceName, interfaceName, interfaceTraitsName);
|
||||
|
||||
// Insert the method declarations.
|
||||
for (auto &method : interface.getMethods()) {
|
||||
os << " " << method.getReturnType() << " ";
|
||||
emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
|
||||
os << ";\n";
|
||||
}
|
||||
os << "};\n";
|
||||
}
|
||||
|
||||
static bool emitInterfaceDecls(const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("Operation Interface Declarations", os);
|
||||
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("OpInterface");
|
||||
for (const auto *def : defs)
|
||||
emitInterfaceDecl(*def, os);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Registers the operation interface generator to mlir-tblgen.
|
||||
static mlir::GenRegistration
|
||||
genInterfaceDecls("gen-op-interface-decls",
|
||||
"Generate op interface declarations",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitInterfaceDecls(records, os);
|
||||
});
|
||||
|
||||
// Registers the operation interface generator to mlir-tblgen.
|
||||
static mlir::GenRegistration
|
||||
genInterfaceDefs("gen-op-interface-defs",
|
||||
"Generate op interface definitions",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitInterfaceDefs(records, os);
|
||||
});
|
||||
Reference in New Issue
Block a user