mirror of
https://github.com/intel/llvm.git
synced 2026-01-24 08:30:34 +08:00
[MLIR] Add support for defining Types in tblgen
Adds a TypeDef class to OpBase and backing generation code. Allows one to define the Type, its parameters, and printer/parser methods in ODS. Can generate the Type C++ class, accessors, storage class, per-parameter custom allocators (for the storage constructor), and documentation. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D86904
This commit is contained in:
@@ -9,6 +9,8 @@ function(add_mlir_dialect dialect dialect_namespace)
|
||||
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
|
||||
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
|
||||
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls)
|
||||
mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs)
|
||||
mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
|
||||
add_public_tablegen_target(MLIR${dialect}IncGen)
|
||||
add_dependencies(mlir-headers MLIR${dialect}IncGen)
|
||||
|
||||
@@ -2364,4 +2364,116 @@ def location;
|
||||
// so to replace the matched DAG with an existing SSA value.
|
||||
def replaceWithValue;
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Data type generation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Define a new type belonging to a dialect and called 'name'.
|
||||
class TypeDef<Dialect owningdialect, string name> {
|
||||
Dialect dialect = owningdialect;
|
||||
string cppClassName = name # "Type";
|
||||
|
||||
// Short summary of the type.
|
||||
string summary = ?;
|
||||
// The longer description of this type.
|
||||
string description = ?;
|
||||
|
||||
// Name of storage class to generate or use.
|
||||
string storageClass = name # "TypeStorage";
|
||||
// Namespace (withing dialect c++ namespace) in which the storage class
|
||||
// resides.
|
||||
string storageNamespace = "detail";
|
||||
// Specify if the storage class is to be generated.
|
||||
bit genStorageClass = 1;
|
||||
// Specify that the generated storage class has a constructor which is written
|
||||
// in C++.
|
||||
bit hasStorageCustomConstructor = 0;
|
||||
|
||||
// The list of parameters for this type. Parameters will become both
|
||||
// parameters to the get() method and storage class member variables.
|
||||
//
|
||||
// The format of this dag is:
|
||||
// (ins
|
||||
// "<c++ type>":$param1Name,
|
||||
// "<c++ type>":$param2Name,
|
||||
// TypeParameter<"c++ type", "param description">:$param3Name)
|
||||
// TypeParameters (or more likely one of their subclasses) are required to add
|
||||
// more information about the parameter, specifically:
|
||||
// - Documentation
|
||||
// - Code to allocate the parameter (if allocation is needed in the storage
|
||||
// class constructor)
|
||||
//
|
||||
// For example:
|
||||
// (ins
|
||||
// "int":$width,
|
||||
// ArrayRefParameter<"bool", "list of bools">:$yesNoArray)
|
||||
//
|
||||
// (ArrayRefParameter is a subclass of TypeParameter which has allocation code
|
||||
// for re-allocating ArrayRefs. It is defined below.)
|
||||
dag parameters = (ins);
|
||||
|
||||
// Use the lowercased name as the keyword for parsing/printing. Specify only
|
||||
// if you want tblgen to generate declarations and/or definitions of
|
||||
// printer/parser for this type.
|
||||
string mnemonic = ?;
|
||||
// If 'mnemonic' specified,
|
||||
// If null, generate just the declarations.
|
||||
// If a non-empty code block, just use that code as the definition code.
|
||||
// Error if an empty code block.
|
||||
code printer = ?;
|
||||
code parser = ?;
|
||||
|
||||
// If set, generate accessors for each Type parameter.
|
||||
bit genAccessors = 1;
|
||||
// Generate the verifyConstructionInvariants declaration and getChecked
|
||||
// method.
|
||||
bit genVerifyInvariantsDecl = 0;
|
||||
// Extra code to include in the class declaration.
|
||||
code extraClassDeclaration = [{}];
|
||||
}
|
||||
|
||||
// 'Parameters' should be subclasses of this or simple strings (which is a
|
||||
// shorthand for TypeParameter<"C++Type">).
|
||||
class TypeParameter<string type, string desc> {
|
||||
// Custom memory allocation code for storage constructor.
|
||||
code allocator = ?;
|
||||
// The C++ type of this parameter.
|
||||
string cppType = type;
|
||||
// A description of this parameter.
|
||||
string description = desc;
|
||||
// The format string for the asm syntax (documentation only).
|
||||
string syntax = ?;
|
||||
}
|
||||
|
||||
// For StringRefs, which require allocation.
|
||||
class StringRefParameter<string desc> :
|
||||
TypeParameter<"::llvm::StringRef", desc> {
|
||||
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
|
||||
}
|
||||
|
||||
// For standard ArrayRefs, which require allocation.
|
||||
class ArrayRefParameter<string arrayOf, string desc> :
|
||||
TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
|
||||
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
|
||||
}
|
||||
|
||||
// For classes which require allocation and have their own allocateInto method.
|
||||
class SelfAllocationParameter<string type, string desc> :
|
||||
TypeParameter<type, desc> {
|
||||
let allocator = [{$_dst = $_self.allocateInto($_allocator);}];
|
||||
}
|
||||
|
||||
// For ArrayRefs which contain things which allocate themselves.
|
||||
class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
|
||||
TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
|
||||
let allocator = [{
|
||||
llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields;
|
||||
for (size_t i = 0, e = $_self.size(); i < e; ++i)
|
||||
tmpFields.push_back($_self[i].allocateInto($_allocator));
|
||||
$_dst = $_allocator.copyInto(ArrayRef<}] # arrayOf # [{>(tmpFields));
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
#endif // OP_BASE
|
||||
|
||||
135
mlir/include/mlir/TableGen/TypeDef.h
Normal file
135
mlir/include/mlir/TableGen/TypeDef.h
Normal file
@@ -0,0 +1,135 @@
|
||||
//===-- TypeDef.h - Record wrapper for type definitions ---------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// TypeDef wrapper to simplify using TableGen Record defining a MLIR type.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TABLEGEN_TYPEDEF_H
|
||||
#define MLIR_TABLEGEN_TYPEDEF_H
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Dialect.h"
|
||||
|
||||
namespace llvm {
|
||||
class Record;
|
||||
class DagInit;
|
||||
class SMLoc;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
class TypeParameter;
|
||||
|
||||
/// Wrapper class that contains a TableGen TypeDef's record and provides helper
|
||||
/// methods for accessing them.
|
||||
class TypeDef {
|
||||
public:
|
||||
explicit TypeDef(const llvm::Record *def) : def(def) {}
|
||||
|
||||
// Get the dialect for which this type belongs.
|
||||
Dialect getDialect() const;
|
||||
|
||||
// Returns the name of this TypeDef record.
|
||||
StringRef getName() const;
|
||||
|
||||
// Query functions for the documentation of the operator.
|
||||
bool hasDescription() const;
|
||||
StringRef getDescription() const;
|
||||
bool hasSummary() const;
|
||||
StringRef getSummary() const;
|
||||
|
||||
// Returns the name of the C++ class to generate.
|
||||
StringRef getCppClassName() const;
|
||||
|
||||
// Returns the name of the storage class for this type.
|
||||
StringRef getStorageClassName() const;
|
||||
|
||||
// Returns the C++ namespace for this types storage class.
|
||||
StringRef getStorageNamespace() const;
|
||||
|
||||
// Returns true if we should generate the storage class.
|
||||
bool genStorageClass() const;
|
||||
|
||||
// Indicates whether or not to generate the storage class constructor.
|
||||
bool hasStorageCustomConstructor() const;
|
||||
|
||||
// Fill a list with this types parameters. See TypeDef in OpBase.td for
|
||||
// documentation of parameter usage.
|
||||
void getParameters(SmallVectorImpl<TypeParameter> &) const;
|
||||
// Return the number of type parameters
|
||||
unsigned getNumParameters() const;
|
||||
|
||||
// Return the keyword/mnemonic to use in the printer/parser methods if we are
|
||||
// supposed to auto-generate them.
|
||||
Optional<StringRef> getMnemonic() const;
|
||||
|
||||
// Returns the code to use as the types printer method. If not specified,
|
||||
// return a non-value. Otherwise, return the contents of that code block.
|
||||
Optional<StringRef> getPrinterCode() const;
|
||||
|
||||
// Returns the code to use as the types parser method. If not specified,
|
||||
// return a non-value. Otherwise, return the contents of that code block.
|
||||
Optional<StringRef> getParserCode() const;
|
||||
|
||||
// Returns true if the accessors based on the types parameters should be
|
||||
// generated.
|
||||
bool genAccessors() const;
|
||||
|
||||
// Return true if we need to generate the verifyConstructionInvariants
|
||||
// declaration and getChecked method.
|
||||
bool genVerifyInvariantsDecl() const;
|
||||
|
||||
// Returns the dialects extra class declaration code.
|
||||
Optional<StringRef> getExtraDecls() const;
|
||||
|
||||
// Get the code location (for error printing).
|
||||
ArrayRef<llvm::SMLoc> getLoc() const;
|
||||
|
||||
// Returns whether two TypeDefs are equal by checking the equality of the
|
||||
// underlying record.
|
||||
bool operator==(const TypeDef &other) const;
|
||||
|
||||
// Compares two TypeDefs by comparing the names of the dialects.
|
||||
bool operator<(const TypeDef &other) const;
|
||||
|
||||
// Returns whether the TypeDef is defined.
|
||||
operator bool() const { return def != nullptr; }
|
||||
|
||||
private:
|
||||
const llvm::Record *def;
|
||||
};
|
||||
|
||||
// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs
|
||||
// to parameterize them.
|
||||
class TypeParameter {
|
||||
public:
|
||||
explicit TypeParameter(const llvm::DagInit *def, unsigned num)
|
||||
: def(def), num(num) {}
|
||||
|
||||
// Get the parameter name.
|
||||
StringRef getName() const;
|
||||
// If specified, get the custom allocator code for this parameter.
|
||||
llvm::Optional<StringRef> getAllocator() const;
|
||||
// Get the C++ type of this parameter.
|
||||
StringRef getCppType() const;
|
||||
// Get a description of this parameter for documentation purposes.
|
||||
llvm::Optional<StringRef> getDescription() const;
|
||||
// Get the assembly syntax documentation.
|
||||
StringRef getSyntax() const;
|
||||
|
||||
private:
|
||||
const llvm::DagInit *def;
|
||||
const unsigned num;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TABLEGEN_TYPEDEF_H
|
||||
@@ -25,6 +25,7 @@ llvm_add_library(MLIRTableGen STATIC
|
||||
SideEffects.cpp
|
||||
Successor.cpp
|
||||
Type.cpp
|
||||
TypeDef.cpp
|
||||
|
||||
DISABLE_LLVM_LINK_LLVM_DYLIB
|
||||
|
||||
|
||||
160
mlir/lib/TableGen/TypeDef.cpp
Normal file
160
mlir/lib/TableGen/TypeDef.cpp
Normal file
@@ -0,0 +1,160 @@
|
||||
//===- TypeDef.cpp - TypeDef wrapper class --------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// TypeDef wrapper to simplify using TableGen Record defining a MLIR dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/TypeDef.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
Dialect TypeDef::getDialect() const {
|
||||
auto *dialectDef =
|
||||
dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
|
||||
if (dialectDef == nullptr)
|
||||
return Dialect(nullptr);
|
||||
return Dialect(dialectDef->getDef());
|
||||
}
|
||||
|
||||
StringRef TypeDef::getName() const { return def->getName(); }
|
||||
StringRef TypeDef::getCppClassName() const {
|
||||
return def->getValueAsString("cppClassName");
|
||||
}
|
||||
|
||||
bool TypeDef::hasDescription() const {
|
||||
const llvm::RecordVal *s = def->getValue("description");
|
||||
return s != nullptr && isa<llvm::StringInit>(s->getValue());
|
||||
}
|
||||
|
||||
StringRef TypeDef::getDescription() const {
|
||||
return def->getValueAsString("description");
|
||||
}
|
||||
|
||||
bool TypeDef::hasSummary() const {
|
||||
const llvm::RecordVal *s = def->getValue("summary");
|
||||
return s != nullptr && isa<llvm::StringInit>(s->getValue());
|
||||
}
|
||||
|
||||
StringRef TypeDef::getSummary() const {
|
||||
return def->getValueAsString("summary");
|
||||
}
|
||||
|
||||
StringRef TypeDef::getStorageClassName() const {
|
||||
return def->getValueAsString("storageClass");
|
||||
}
|
||||
StringRef TypeDef::getStorageNamespace() const {
|
||||
return def->getValueAsString("storageNamespace");
|
||||
}
|
||||
|
||||
bool TypeDef::genStorageClass() const {
|
||||
return def->getValueAsBit("genStorageClass");
|
||||
}
|
||||
bool TypeDef::hasStorageCustomConstructor() const {
|
||||
return def->getValueAsBit("hasStorageCustomConstructor");
|
||||
}
|
||||
void TypeDef::getParameters(SmallVectorImpl<TypeParameter> ¶meters) const {
|
||||
auto *parametersDag = def->getValueAsDag("parameters");
|
||||
if (parametersDag != nullptr) {
|
||||
size_t numParams = parametersDag->getNumArgs();
|
||||
for (unsigned i = 0; i < numParams; i++)
|
||||
parameters.push_back(TypeParameter(parametersDag, i));
|
||||
}
|
||||
}
|
||||
unsigned TypeDef::getNumParameters() const {
|
||||
auto *parametersDag = def->getValueAsDag("parameters");
|
||||
return parametersDag ? parametersDag->getNumArgs() : 0;
|
||||
}
|
||||
llvm::Optional<StringRef> TypeDef::getMnemonic() const {
|
||||
return def->getValueAsOptionalString("mnemonic");
|
||||
}
|
||||
llvm::Optional<StringRef> TypeDef::getPrinterCode() const {
|
||||
return def->getValueAsOptionalCode("printer");
|
||||
}
|
||||
llvm::Optional<StringRef> TypeDef::getParserCode() const {
|
||||
return def->getValueAsOptionalCode("parser");
|
||||
}
|
||||
bool TypeDef::genAccessors() const {
|
||||
return def->getValueAsBit("genAccessors");
|
||||
}
|
||||
bool TypeDef::genVerifyInvariantsDecl() const {
|
||||
return def->getValueAsBit("genVerifyInvariantsDecl");
|
||||
}
|
||||
llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
|
||||
auto value = def->getValueAsString("extraClassDeclaration");
|
||||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }
|
||||
bool TypeDef::operator==(const TypeDef &other) const {
|
||||
return def == other.def;
|
||||
}
|
||||
|
||||
bool TypeDef::operator<(const TypeDef &other) const {
|
||||
return getName() < other.getName();
|
||||
}
|
||||
|
||||
StringRef TypeParameter::getName() const {
|
||||
return def->getArgName(num)->getValue();
|
||||
}
|
||||
llvm::Optional<StringRef> TypeParameter::getAllocator() const {
|
||||
llvm::Init *parameterType = def->getArg(num);
|
||||
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
|
||||
return llvm::Optional<StringRef>();
|
||||
|
||||
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
|
||||
llvm::RecordVal *code = typeParameter->getDef()->getValue("allocator");
|
||||
if (llvm::CodeInit *ci = dyn_cast<llvm::CodeInit>(code->getValue()))
|
||||
return ci->getValue();
|
||||
if (isa<llvm::UnsetInit>(code->getValue()))
|
||||
return llvm::Optional<StringRef>();
|
||||
|
||||
llvm::PrintFatalError(
|
||||
typeParameter->getDef()->getLoc(),
|
||||
"Record `" + def->getArgName(num)->getValue() +
|
||||
"', field `printer' does not have a code initializer!");
|
||||
}
|
||||
|
||||
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
|
||||
"defs which inherit from TypeParameter\n");
|
||||
}
|
||||
StringRef TypeParameter::getCppType() const {
|
||||
auto *parameterType = def->getArg(num);
|
||||
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
|
||||
return stringType->getValue();
|
||||
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType))
|
||||
return typeParameter->getDef()->getValueAsString("cppType");
|
||||
llvm::PrintFatalError(
|
||||
"Parameters DAG arguments must be either strings or defs "
|
||||
"which inherit from TypeParameter\n");
|
||||
}
|
||||
llvm::Optional<StringRef> TypeParameter::getDescription() const {
|
||||
auto *parameterType = def->getArg(num);
|
||||
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
|
||||
const auto *desc = typeParameter->getDef()->getValue("description");
|
||||
if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
|
||||
return ci->getValue();
|
||||
}
|
||||
return llvm::Optional<StringRef>();
|
||||
}
|
||||
StringRef TypeParameter::getSyntax() const {
|
||||
auto *parameterType = def->getArg(num);
|
||||
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
|
||||
return stringType->getValue();
|
||||
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
|
||||
const auto *syntax = typeParameter->getDef()->getValue("syntax");
|
||||
if (syntax && isa<llvm::StringInit>(syntax->getValue()))
|
||||
return dyn_cast<llvm::StringInit>(syntax->getValue())->getValue();
|
||||
return getCppType();
|
||||
}
|
||||
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
|
||||
"defs which inherit from TypeParameter");
|
||||
}
|
||||
@@ -9,6 +9,12 @@ mlir_tablegen(TestTypeInterfaces.h.inc -gen-type-interface-decls)
|
||||
mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs)
|
||||
add_public_tablegen_target(MLIRTestInterfaceIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
|
||||
mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls)
|
||||
mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs)
|
||||
add_public_tablegen_target(MLIRTestDefIncGen)
|
||||
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS TestOps.td)
|
||||
mlir_tablegen(TestOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
|
||||
@@ -25,11 +31,13 @@ add_mlir_library(MLIRTestDialect
|
||||
TestDialect.cpp
|
||||
TestPatterns.cpp
|
||||
TestTraits.cpp
|
||||
TestTypes.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
DEPENDS
|
||||
MLIRTestInterfaceIncGen
|
||||
MLIRTestDefIncGen
|
||||
MLIRTestOpsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
|
||||
@@ -141,16 +141,23 @@ void TestDialect::initialize() {
|
||||
>();
|
||||
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
|
||||
TestInlinerInterface>();
|
||||
addTypes<TestType, TestRecursiveType>();
|
||||
addTypes<TestType, TestRecursiveType,
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "TestTypeDefs.cpp.inc"
|
||||
>();
|
||||
allowUnknownOperations();
|
||||
}
|
||||
|
||||
static Type parseTestType(DialectAsmParser &parser,
|
||||
static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
|
||||
llvm::SetVector<Type> &stack) {
|
||||
StringRef typeTag;
|
||||
if (failed(parser.parseKeyword(&typeTag)))
|
||||
return Type();
|
||||
|
||||
auto genType = generatedTypeParser(ctxt, parser, typeTag);
|
||||
if (genType != Type())
|
||||
return genType;
|
||||
|
||||
if (typeTag == "test_type")
|
||||
return TestType::get(parser.getBuilder().getContext());
|
||||
|
||||
@@ -174,7 +181,7 @@ static Type parseTestType(DialectAsmParser &parser,
|
||||
if (failed(parser.parseComma()))
|
||||
return Type();
|
||||
stack.insert(rec);
|
||||
Type subtype = parseTestType(parser, stack);
|
||||
Type subtype = parseTestType(ctxt, parser, stack);
|
||||
stack.pop_back();
|
||||
if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
|
||||
return Type();
|
||||
@@ -184,11 +191,13 @@ static Type parseTestType(DialectAsmParser &parser,
|
||||
|
||||
Type TestDialect::parseType(DialectAsmParser &parser) const {
|
||||
llvm::SetVector<Type> stack;
|
||||
return parseTestType(parser, stack);
|
||||
return parseTestType(getContext(), parser, stack);
|
||||
}
|
||||
|
||||
static void printTestType(Type type, DialectAsmPrinter &printer,
|
||||
llvm::SetVector<Type> &stack) {
|
||||
if (succeeded(generatedTypePrinter(type, printer)))
|
||||
return;
|
||||
if (type.isa<TestType>()) {
|
||||
printer << "test_type";
|
||||
return;
|
||||
|
||||
150
mlir/test/lib/Dialect/Test/TestTypeDefs.td
Normal file
150
mlir/test/lib/Dialect/Test/TestTypeDefs.td
Normal file
@@ -0,0 +1,150 @@
|
||||
//===-- TestTypeDefs.td - Test dialect type definitions ----*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// TableGen data type definitions for Test dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef TEST_TYPEDEFS
|
||||
#define TEST_TYPEDEFS
|
||||
|
||||
// To get the test dialect def.
|
||||
include "TestOps.td"
|
||||
|
||||
// All of the types will extend this class.
|
||||
class Test_Type<string name> : TypeDef<Test_Dialect, name> { }
|
||||
|
||||
def SimpleTypeA : Test_Type<"SimpleA"> {
|
||||
let mnemonic = "smpla";
|
||||
|
||||
let printer = [{ $_printer << "smpla"; }];
|
||||
let parser = [{ return get($_ctxt); }];
|
||||
}
|
||||
|
||||
// A more complex parameterized type.
|
||||
def CompoundTypeA : Test_Type<"CompoundA"> {
|
||||
let mnemonic = "cmpnd_a";
|
||||
|
||||
// List of type parameters.
|
||||
let parameters = (
|
||||
ins
|
||||
"int":$widthOfSomething,
|
||||
"::mlir::Type":$oneType,
|
||||
// This is special syntax since ArrayRefs require allocation in the
|
||||
// constructor.
|
||||
ArrayRefParameter<
|
||||
"int", // The parameter C++ type.
|
||||
"An example of an array of ints" // Parameter description.
|
||||
>: $arrayOfInts
|
||||
);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
struct SomeCppStruct {};
|
||||
}];
|
||||
}
|
||||
|
||||
// An example of how one could implement a standard integer.
|
||||
def IntegerType : Test_Type<"TestInteger"> {
|
||||
let mnemonic = "int";
|
||||
let genVerifyInvariantsDecl = 1;
|
||||
let parameters = (
|
||||
ins
|
||||
// SignednessSemantics is defined below.
|
||||
"::mlir::TestIntegerType::SignednessSemantics":$signedness,
|
||||
"unsigned":$width
|
||||
);
|
||||
|
||||
// We define the printer inline.
|
||||
let printer = [{
|
||||
$_printer << "int<";
|
||||
printSignedness($_printer, getImpl()->signedness);
|
||||
$_printer << ", " << getImpl()->width << ">";
|
||||
}];
|
||||
|
||||
// The parser is defined here also.
|
||||
let parser = [{
|
||||
if (parser.parseLess()) return Type();
|
||||
SignednessSemantics signedness;
|
||||
if (parseSignedness($_parser, signedness)) return mlir::Type();
|
||||
if ($_parser.parseComma()) return Type();
|
||||
int width;
|
||||
if ($_parser.parseInteger(width)) return Type();
|
||||
if ($_parser.parseGreater()) return Type();
|
||||
return get(ctxt, signedness, width);
|
||||
}];
|
||||
|
||||
// Any extra code one wants in the type's class declaration.
|
||||
let extraClassDeclaration = [{
|
||||
/// Signedness semantics.
|
||||
enum SignednessSemantics {
|
||||
Signless, /// No signedness semantics
|
||||
Signed, /// Signed integer
|
||||
Unsigned, /// Unsigned integer
|
||||
};
|
||||
|
||||
/// This extra function is necessary since it doesn't include signedness
|
||||
static IntegerType getChecked(unsigned width, Location location);
|
||||
|
||||
/// Return true if this is a signless integer type.
|
||||
bool isSignless() const { return getSignedness() == Signless; }
|
||||
/// Return true if this is a signed integer type.
|
||||
bool isSigned() const { return getSignedness() == Signed; }
|
||||
/// Return true if this is an unsigned integer type.
|
||||
bool isUnsigned() const { return getSignedness() == Unsigned; }
|
||||
}];
|
||||
}
|
||||
|
||||
// A parent type for any type which is just a list of fields (e.g. structs,
|
||||
// unions).
|
||||
class FieldInfo_Type<string name> : Test_Type<name> {
|
||||
let parameters = (
|
||||
ins
|
||||
// An ArrayRef of something which requires allocation in the storage
|
||||
// constructor.
|
||||
ArrayRefOfSelfAllocationParameter<
|
||||
"::mlir::FieldInfo", // FieldInfo is defined/declared in TestTypes.h.
|
||||
"Models struct fields">: $fields
|
||||
);
|
||||
|
||||
// Prints the type in this format:
|
||||
// struct<[{field1Name, field1Type}, {field2Name, field2Type}]
|
||||
let printer = [{
|
||||
$_printer << "struct" << "<";
|
||||
for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) {
|
||||
const auto& field = getImpl()->fields[i];
|
||||
$_printer << "{" << field.name << "," << field.type << "}";
|
||||
if (i < getImpl()->fields.size() - 1)
|
||||
$_printer << ",";
|
||||
}
|
||||
$_printer << ">";
|
||||
}];
|
||||
|
||||
// Parses the above format
|
||||
let parser = [{
|
||||
llvm::SmallVector<FieldInfo, 4> parameters;
|
||||
if ($_parser.parseLess()) return Type();
|
||||
while (mlir::succeeded($_parser.parseOptionalLBrace())) {
|
||||
StringRef name;
|
||||
if ($_parser.parseKeyword(&name)) return Type();
|
||||
if ($_parser.parseComma()) return Type();
|
||||
Type type;
|
||||
if ($_parser.parseType(type)) return Type();
|
||||
if ($_parser.parseRBrace()) return Type();
|
||||
parameters.push_back(FieldInfo {name, type});
|
||||
if ($_parser.parseOptionalComma()) break;
|
||||
}
|
||||
if ($_parser.parseGreater()) return Type();
|
||||
return get($_ctxt, parameters);
|
||||
}];
|
||||
}
|
||||
|
||||
def StructType : FieldInfo_Type<"Struct"> {
|
||||
let mnemonic = "struct";
|
||||
}
|
||||
|
||||
#endif // TEST_TYPEDEFS
|
||||
117
mlir/test/lib/Dialect/Test/TestTypes.cpp
Normal file
117
mlir/test/lib/Dialect/Test/TestTypes.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains types defined by the TestDialect for testing various
|
||||
// features of MLIR.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "TestTypes.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
// Custom parser for SignednessSemantics.
|
||||
static ParseResult
|
||||
parseSignedness(DialectAsmParser &parser,
|
||||
TestIntegerType::SignednessSemantics &result) {
|
||||
StringRef signStr;
|
||||
auto loc = parser.getCurrentLocation();
|
||||
if (parser.parseKeyword(&signStr))
|
||||
return failure();
|
||||
if (signStr.compare_lower("u") || signStr.compare_lower("unsigned"))
|
||||
result = TestIntegerType::SignednessSemantics::Unsigned;
|
||||
else if (signStr.compare_lower("s") || signStr.compare_lower("signed"))
|
||||
result = TestIntegerType::SignednessSemantics::Signed;
|
||||
else if (signStr.compare_lower("n") || signStr.compare_lower("none"))
|
||||
result = TestIntegerType::SignednessSemantics::Signless;
|
||||
else
|
||||
return parser.emitError(loc, "expected signed, unsigned, or none");
|
||||
return success();
|
||||
}
|
||||
|
||||
// Custom printer for SignednessSemantics.
|
||||
static void printSignedness(DialectAsmPrinter &printer,
|
||||
const TestIntegerType::SignednessSemantics &ss) {
|
||||
switch (ss) {
|
||||
case TestIntegerType::SignednessSemantics::Unsigned:
|
||||
printer << "unsigned";
|
||||
break;
|
||||
case TestIntegerType::SignednessSemantics::Signed:
|
||||
printer << "signed";
|
||||
break;
|
||||
case TestIntegerType::SignednessSemantics::Signless:
|
||||
printer << "none";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) {
|
||||
int widthOfSomething;
|
||||
Type oneType;
|
||||
SmallVector<int, 4> arrayOfInts;
|
||||
if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
|
||||
parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
|
||||
parser.parseLSquare())
|
||||
return Type();
|
||||
|
||||
int i;
|
||||
while (!*parser.parseOptionalInteger(i)) {
|
||||
arrayOfInts.push_back(i);
|
||||
if (parser.parseOptionalComma())
|
||||
break;
|
||||
}
|
||||
|
||||
if (parser.parseRSquare() || parser.parseGreater())
|
||||
return Type();
|
||||
|
||||
return get(ctxt, widthOfSomething, oneType, arrayOfInts);
|
||||
}
|
||||
void CompoundAType::print(DialectAsmPrinter &printer) const {
|
||||
printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
|
||||
<< ", [";
|
||||
auto intArray = getArrayOfInts();
|
||||
llvm::interleaveComma(intArray, printer);
|
||||
printer << "]>";
|
||||
}
|
||||
|
||||
// The functions don't need to be in the header file, but need to be in the mlir
|
||||
// namespace. Declare them here, then define them immediately below. Separating
|
||||
// the declaration and definition adheres to the LLVM coding standards.
|
||||
namespace mlir {
|
||||
// FieldInfo is used as part of a parameter, so equality comparison is
|
||||
// compulsory.
|
||||
static bool operator==(const FieldInfo &a, const FieldInfo &b);
|
||||
// FieldInfo is used as part of a parameter, so a hash will be computed.
|
||||
static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
|
||||
} // namespace mlir
|
||||
|
||||
// FieldInfo is used as part of a parameter, so equality comparison is
|
||||
// compulsory.
|
||||
static bool mlir::operator==(const FieldInfo &a, const FieldInfo &b) {
|
||||
return a.name == b.name && a.type == b.type;
|
||||
}
|
||||
|
||||
// FieldInfo is used as part of a parameter, so a hash will be computed.
|
||||
static llvm::hash_code mlir::hash_value(const FieldInfo &fi) { // NOLINT
|
||||
return llvm::hash_combine(fi.name, fi.type);
|
||||
}
|
||||
|
||||
// Example type validity checker.
|
||||
LogicalResult TestIntegerType::verifyConstructionInvariants(
|
||||
Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) {
|
||||
if (width > 8)
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "TestTypeDefs.cpp.inc"
|
||||
@@ -14,11 +14,35 @@
|
||||
#ifndef MLIR_TESTTYPES_H
|
||||
#define MLIR_TESTTYPES_H
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// FieldInfo represents a field in the StructType data type. It is used as a
|
||||
/// parameter in TestTypeDefs.td.
|
||||
struct FieldInfo {
|
||||
StringRef name;
|
||||
Type type;
|
||||
|
||||
// Custom allocation called from generated constructor code
|
||||
FieldInfo allocateInto(TypeStorageAllocator &alloc) const {
|
||||
return FieldInfo{alloc.copyInto(name), type};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "TestTypeDefs.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#include "TestTypeInterfaces.h.inc"
|
||||
|
||||
/// This class is a simple test type that uses a generated interface.
|
||||
|
||||
24
mlir/test/mlir-tblgen/testdialect-typedefs.mlir
Normal file
24
mlir/test/mlir-tblgen/testdialect-typedefs.mlir
Normal file
@@ -0,0 +1,24 @@
|
||||
// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
|
||||
|
||||
//////////////
|
||||
// Tests the types in the 'Test' dialect, not the ones in 'typedefs.mlir'
|
||||
|
||||
// CHECK: @simpleA(%arg0: !test.smpla)
|
||||
func @simpleA(%A : !test.smpla) -> () {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: @compoundA(%arg0: !test.cmpnd_a<1, !test.smpla, [5, 6]>)
|
||||
func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: @testInt(%arg0: !test.int<unsigned, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<unsigned, 1>)
|
||||
func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<unsigned, 3>}>)
|
||||
func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int<none, 3>} > ) {
|
||||
return
|
||||
}
|
||||
132
mlir/test/mlir-tblgen/typedefs.td
Normal file
132
mlir/test/mlir-tblgen/typedefs.td
Normal file
@@ -0,0 +1,132 @@
|
||||
// RUN: mlir-tblgen -gen-typedef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
|
||||
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
// DECL: #ifdef GET_TYPEDEF_CLASSES
|
||||
// DECL: #undef GET_TYPEDEF_CLASSES
|
||||
|
||||
// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);
|
||||
// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer);
|
||||
|
||||
// DEF: #ifdef GET_TYPEDEF_LIST
|
||||
// DEF: #undef GET_TYPEDEF_LIST
|
||||
// DEF: ::mlir::test::SimpleAType,
|
||||
// DEF: ::mlir::test::CompoundAType,
|
||||
// DEF: ::mlir::test::IndexType,
|
||||
// DEF: ::mlir::test::SingleParameterType,
|
||||
// DEF: ::mlir::test::IntegerType
|
||||
|
||||
// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic)
|
||||
// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser);
|
||||
// DEF return ::mlir::Type();
|
||||
|
||||
def Test_Dialect: Dialect {
|
||||
// DECL-NOT: TestDialect
|
||||
// DEF-NOT: TestDialect
|
||||
let name = "TestDialect";
|
||||
let cppNamespace = "::mlir::test";
|
||||
}
|
||||
|
||||
class TestType<string name> : TypeDef<Test_Dialect, name> { }
|
||||
|
||||
def A_SimpleTypeA : TestType<"SimpleA"> {
|
||||
// DECL: class SimpleAType: public ::mlir::Type
|
||||
}
|
||||
|
||||
// A more complex parameterized type
|
||||
def B_CompoundTypeA : TestType<"CompoundA"> {
|
||||
let summary = "A more complex parameterized type";
|
||||
let description = "This type is to test a reasonably complex type";
|
||||
let mnemonic = "cmpnd_a";
|
||||
let parameters = (
|
||||
ins
|
||||
"int":$widthOfSomething,
|
||||
"::mlir::test::SimpleTypeA": $exampleTdType,
|
||||
"SomeCppStruct": $exampleCppType,
|
||||
ArrayRefParameter<"int", "Matrix dimensions">:$dims
|
||||
);
|
||||
|
||||
let genVerifyInvariantsDecl = 1;
|
||||
|
||||
// DECL-LABEL: class CompoundAType: public ::mlir::Type
|
||||
// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
|
||||
// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
|
||||
// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
|
||||
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
|
||||
// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
|
||||
// DECL: int getWidthOfSomething() const;
|
||||
// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
|
||||
// DECL: SomeCppStruct getExampleCppType() const;
|
||||
}
|
||||
|
||||
def C_IndexType : TestType<"Index"> {
|
||||
let mnemonic = "index";
|
||||
|
||||
let parameters = (
|
||||
ins
|
||||
StringRefParameter<"Label for index">:$label
|
||||
);
|
||||
|
||||
// DECL-LABEL: class IndexType: public ::mlir::Type
|
||||
// DECL: static ::llvm::StringRef getMnemonic() { return "index"; }
|
||||
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
|
||||
// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
|
||||
}
|
||||
|
||||
def D_SingleParameterType : TestType<"SingleParameter"> {
|
||||
let parameters = (
|
||||
ins
|
||||
"int": $num
|
||||
);
|
||||
// DECL-LABEL: struct SingleParameterTypeStorage;
|
||||
// DECL-LABEL: class SingleParameterType
|
||||
// DECL-NEXT: detail::SingleParameterTypeStorage
|
||||
}
|
||||
|
||||
def E_IntegerType : TestType<"Integer"> {
|
||||
let mnemonic = "int";
|
||||
let genVerifyInvariantsDecl = 1;
|
||||
let parameters = (
|
||||
ins
|
||||
"SignednessSemantics":$signedness,
|
||||
TypeParameter<"unsigned", "Bitwdith of integer">:$width
|
||||
);
|
||||
|
||||
// DECL-LABEL: IntegerType: public ::mlir::Type
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Signedness semantics.
|
||||
enum SignednessSemantics {
|
||||
Signless, /// No signedness semantics
|
||||
Signed, /// Signed integer
|
||||
Unsigned, /// Unsigned integer
|
||||
};
|
||||
|
||||
/// This extra function is necessary since it doesn't include signedness
|
||||
static IntegerType getChecked(unsigned width, Location location);
|
||||
|
||||
/// Return true if this is a signless integer type.
|
||||
bool isSignless() const { return getSignedness() == Signless; }
|
||||
/// Return true if this is a signed integer type.
|
||||
bool isSigned() const { return getSignedness() == Signed; }
|
||||
/// Return true if this is an unsigned integer type.
|
||||
bool isUnsigned() const { return getSignedness() == Unsigned; }
|
||||
}];
|
||||
|
||||
// DECL: /// Signedness semantics.
|
||||
// DECL-NEXT: enum SignednessSemantics {
|
||||
// DECL-NEXT: Signless, /// No signedness semantics
|
||||
// DECL-NEXT: Signed, /// Signed integer
|
||||
// DECL-NEXT: Unsigned, /// Unsigned integer
|
||||
// DECL-NEXT: };
|
||||
// DECL: /// This extra function is necessary since it doesn't include signedness
|
||||
// DECL-NEXT: static IntegerType getChecked(unsigned width, Location location);
|
||||
|
||||
// DECL: /// Return true if this is a signless integer type.
|
||||
// DECL-NEXT: bool isSignless() const { return getSignedness() == Signless; }
|
||||
// DECL-NEXT: /// Return true if this is a signed integer type.
|
||||
// DECL-NEXT: bool isSigned() const { return getSignedness() == Signed; }
|
||||
// DECL-NEXT: /// Return true if this is an unsigned integer type.
|
||||
// DECL-NEXT: bool isUnsigned() const { return getSignedness() == Unsigned; }
|
||||
}
|
||||
@@ -20,6 +20,7 @@ add_tablegen(mlir-tblgen MLIR
|
||||
RewriterGen.cpp
|
||||
SPIRVUtilsGen.cpp
|
||||
StructsGen.cpp
|
||||
TypeDefGen.cpp
|
||||
)
|
||||
|
||||
set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "mlir/Support/IndentedOstream.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "mlir/TableGen/TypeDef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
@@ -23,6 +24,8 @@
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
@@ -155,12 +158,67 @@ static void emitTypeDoc(const Type &type, raw_ostream &os) {
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeDef Documentation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Emit the assembly format of a type.
|
||||
static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) {
|
||||
SmallVector<TypeParameter, 4> parameters;
|
||||
td.getParameters(parameters);
|
||||
if (parameters.size() == 0) {
|
||||
os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic()
|
||||
<< "`\n";
|
||||
return;
|
||||
}
|
||||
|
||||
os << "\nSyntax:\n\n```\n!" << td.getDialect().getName() << "."
|
||||
<< td.getMnemonic() << "<\n";
|
||||
for (auto *it = parameters.begin(), *e = parameters.end(); it < e; ++it) {
|
||||
os << " " << it->getSyntax();
|
||||
if (it < parameters.end() - 1)
|
||||
os << ",";
|
||||
os << " # " << it->getName() << "\n";
|
||||
}
|
||||
os << ">\n```\n";
|
||||
}
|
||||
|
||||
static void emitTypeDefDoc(TypeDef td, raw_ostream &os) {
|
||||
os << llvm::formatv("### `{0}` ({1})\n", td.getName(), td.getCppClassName());
|
||||
|
||||
// Emit the summary, syntax, and description if present.
|
||||
if (td.hasSummary())
|
||||
os << "\n" << td.getSummary() << "\n";
|
||||
if (td.getMnemonic() && td.getPrinterCode() && *td.getPrinterCode() == "" &&
|
||||
td.getParserCode() && *td.getParserCode() == "")
|
||||
emitTypeAssemblyFormat(td, os);
|
||||
if (td.hasDescription())
|
||||
mlir::tblgen::emitDescription(td.getDescription(), os);
|
||||
|
||||
// Emit attribute documentation.
|
||||
SmallVector<TypeParameter, 4> parameters;
|
||||
td.getParameters(parameters);
|
||||
if (parameters.size() != 0) {
|
||||
os << "\n#### Type parameters:\n\n";
|
||||
os << "| Parameter | C++ type | Description |\n"
|
||||
<< "| :-------: | :-------: | ----------- |\n";
|
||||
for (const auto &it : parameters) {
|
||||
auto desc = it.getDescription();
|
||||
os << "| " << it.getName() << " | `" << td.getCppClassName() << "` | "
|
||||
<< (desc ? *desc : "") << " |\n";
|
||||
}
|
||||
}
|
||||
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect Documentation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
|
||||
ArrayRef<Type> types, raw_ostream &os) {
|
||||
ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
|
||||
raw_ostream &os) {
|
||||
os << "# '" << dialect.getName() << "' Dialect\n\n";
|
||||
emitIfNotEmpty(dialect.getSummary(), os);
|
||||
emitIfNotEmpty(dialect.getDescription(), os);
|
||||
@@ -169,7 +227,7 @@ static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
|
||||
|
||||
// TODO: Add link between use and def for types
|
||||
if (!types.empty()) {
|
||||
os << "## Type definition\n\n";
|
||||
os << "## Type constraint definition\n\n";
|
||||
for (const Type &type : types)
|
||||
emitTypeDoc(type, os);
|
||||
}
|
||||
@@ -179,28 +237,43 @@ static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
|
||||
for (const Operator &op : ops)
|
||||
emitOpDoc(op, os);
|
||||
}
|
||||
|
||||
if (!typeDefs.empty()) {
|
||||
os << "## Type definition\n\n";
|
||||
for (const TypeDef &td : typeDefs)
|
||||
emitTypeDefDoc(td, os);
|
||||
}
|
||||
}
|
||||
|
||||
static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
||||
const auto &opDefs = recordKeeper.getAllDerivedDefinitions("Op");
|
||||
const auto &typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
|
||||
const auto &typeDefDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
|
||||
|
||||
std::set<Dialect> dialectsWithDocs;
|
||||
std::map<Dialect, std::vector<Operator>> dialectOps;
|
||||
std::map<Dialect, std::vector<Type>> dialectTypes;
|
||||
std::map<Dialect, std::vector<TypeDef>> dialectTypeDefs;
|
||||
for (auto *opDef : opDefs) {
|
||||
Operator op(opDef);
|
||||
dialectOps[op.getDialect()].push_back(op);
|
||||
dialectsWithDocs.insert(op.getDialect());
|
||||
}
|
||||
for (auto *typeDef : typeDefs) {
|
||||
Type type(typeDef);
|
||||
if (auto dialect = type.getDialect())
|
||||
dialectTypes[dialect].push_back(type);
|
||||
}
|
||||
for (auto *typeDef : typeDefDefs) {
|
||||
TypeDef type(typeDef);
|
||||
dialectTypeDefs[type.getDialect()].push_back(type);
|
||||
dialectsWithDocs.insert(type.getDialect());
|
||||
}
|
||||
|
||||
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
|
||||
for (const auto &dialectWithOps : dialectOps)
|
||||
emitDialectDoc(dialectWithOps.first, dialectWithOps.second,
|
||||
dialectTypes[dialectWithOps.first], os);
|
||||
for (auto dialect : dialectsWithDocs)
|
||||
emitDialectDoc(dialect, dialectOps[dialect], dialectTypes[dialect],
|
||||
dialectTypeDefs[dialect], os);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
561
mlir/tools/mlir-tblgen/TypeDefGen.cpp
Normal file
561
mlir/tools/mlir-tblgen/TypeDefGen.cpp
Normal file
@@ -0,0 +1,561 @@
|
||||
//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// TypeDefGen uses the description of typeDefs to generate C++ definitions.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/TableGen/CodeGenHelpers.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/TypeDef.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
#define DEBUG_TYPE "mlir-tblgen-typedefgen"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
|
||||
static llvm::cl::opt<std::string>
|
||||
selectedDialect("typedefs-dialect",
|
||||
llvm::cl::desc("Gen types for this dialect"),
|
||||
llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
|
||||
|
||||
/// Find all the TypeDefs for the specified dialect. If no dialect specified and
|
||||
/// can only find one dialect's types, use that.
|
||||
static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
|
||||
SmallVectorImpl<TypeDef> &typeDefs) {
|
||||
auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
|
||||
auto defs = llvm::map_range(
|
||||
recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
|
||||
if (defs.empty())
|
||||
return;
|
||||
|
||||
StringRef dialectName;
|
||||
if (selectedDialect.getNumOccurrences() == 0) {
|
||||
if (defs.empty())
|
||||
return;
|
||||
|
||||
llvm::SmallSet<Dialect, 4> dialects;
|
||||
for (const TypeDef &typeDef : defs)
|
||||
dialects.insert(typeDef.getDialect());
|
||||
if (dialects.size() != 1)
|
||||
llvm::PrintFatalError("TypeDefs belonging to more than one dialect. Must "
|
||||
"select one via '--typedefs-dialect'");
|
||||
|
||||
dialectName = (*dialects.begin()).getName();
|
||||
} else if (selectedDialect.getNumOccurrences() == 1) {
|
||||
dialectName = selectedDialect.getValue();
|
||||
} else {
|
||||
llvm::PrintFatalError("Cannot select multiple dialects for which to "
|
||||
"generate types via '--typedefs-dialect'.");
|
||||
}
|
||||
|
||||
for (const TypeDef &typeDef : defs)
|
||||
if (typeDef.getDialect().getName().equals(dialectName))
|
||||
typeDefs.push_back(typeDef);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Pass an instance of this class to llvm::formatv() to emit a comma separated
|
||||
/// list of parameters in the format by 'EmitFormat'.
|
||||
class TypeParamCommaFormatter : public llvm::detail::format_adapter {
|
||||
public:
|
||||
/// Choose the output format
|
||||
enum EmitFormat {
|
||||
/// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
|
||||
/// [...]".
|
||||
TypeNamePairs,
|
||||
|
||||
/// Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name,
|
||||
/// [...]".
|
||||
TypeNamePairsPrependComma,
|
||||
|
||||
/// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
|
||||
TypeNameInitializer
|
||||
};
|
||||
|
||||
TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params)
|
||||
: emitFormat(emitFormat), params(params) {}
|
||||
|
||||
/// llvm::formatv will call this function when using an instance as a
|
||||
/// replacement value.
|
||||
void format(raw_ostream &os, StringRef options) {
|
||||
if (params.size() && emitFormat == EmitFormat::TypeNamePairsPrependComma)
|
||||
os << ", ";
|
||||
switch (emitFormat) {
|
||||
case EmitFormat::TypeNamePairs:
|
||||
case EmitFormat::TypeNamePairsPrependComma:
|
||||
interleaveComma(params, os,
|
||||
[&](const TypeParameter &p) { emitTypeNamePair(p, os); });
|
||||
break;
|
||||
case EmitFormat::TypeNameInitializer:
|
||||
interleaveComma(params, os, [&](const TypeParameter &p) {
|
||||
emitTypeNameInitializer(p, os);
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// Emit "paramType paramName".
|
||||
static void emitTypeNamePair(const TypeParameter ¶m, raw_ostream &os) {
|
||||
os << param.getCppType() << " " << param.getName();
|
||||
}
|
||||
// Emit "paramName(paramName)"
|
||||
void emitTypeNameInitializer(const TypeParameter ¶m, raw_ostream &os) {
|
||||
os << param.getName() << "(" << param.getName() << ")";
|
||||
}
|
||||
|
||||
EmitFormat emitFormat;
|
||||
ArrayRef<TypeParameter> params;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: TypeDef declarations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The code block for the start of a typeDef class declaration -- singleton
|
||||
/// case.
|
||||
///
|
||||
/// {0}: The name of the typeDef class.
|
||||
static const char *const typeDefDeclSingletonBeginStr = R"(
|
||||
class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{
|
||||
public:
|
||||
/// Inherit some necessary constructors from 'TypeBase'.
|
||||
using Base::Base;
|
||||
|
||||
)";
|
||||
|
||||
/// The code block for the start of a typeDef class declaration -- parametric
|
||||
/// case.
|
||||
///
|
||||
/// {0}: The name of the typeDef class.
|
||||
/// {1}: The typeDef storage class namespace.
|
||||
/// {2}: The storage class name.
|
||||
/// {3}: The list of parameters with types.
|
||||
static const char *const typeDefDeclParametricBeginStr = R"(
|
||||
namespace {1} {
|
||||
struct {2};
|
||||
}
|
||||
class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type,
|
||||
{1}::{2}> {{
|
||||
public:
|
||||
/// Inherit some necessary constructors from 'TypeBase'.
|
||||
using Base::Base;
|
||||
|
||||
)";
|
||||
|
||||
/// The snippet for print/parse.
|
||||
static const char *const typeDefParsePrint = R"(
|
||||
static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
|
||||
void print(::mlir::DialectAsmPrinter& printer) const;
|
||||
)";
|
||||
|
||||
/// The code block for the verifyConstructionInvariants and getChecked.
|
||||
///
|
||||
/// {0}: List of parameters, parameters style.
|
||||
/// {1}: C++ type class name.
|
||||
static const char *const typeDefDeclVerifyStr = R"(
|
||||
static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
|
||||
static {1} getChecked(Location loc{0});
|
||||
)";
|
||||
|
||||
/// Generate the declaration for the given typeDef class.
|
||||
static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
|
||||
SmallVector<TypeParameter, 4> params;
|
||||
typeDef.getParameters(params);
|
||||
|
||||
// Emit the beginning string template: either the singleton or parametric
|
||||
// template.
|
||||
if (typeDef.getNumParameters() == 0)
|
||||
os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
|
||||
typeDef.getStorageNamespace(), typeDef.getStorageClassName());
|
||||
else
|
||||
os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
|
||||
typeDef.getStorageNamespace(), typeDef.getStorageClassName());
|
||||
|
||||
// Emit the extra declarations first in case there's a type definition in
|
||||
// there.
|
||||
if (Optional<StringRef> extraDecl = typeDef.getExtraDecls())
|
||||
os << *extraDecl << "\n";
|
||||
|
||||
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
|
||||
TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, params);
|
||||
os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
|
||||
typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
|
||||
|
||||
// Emit the verify invariants declaration.
|
||||
if (typeDef.genVerifyInvariantsDecl())
|
||||
os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma,
|
||||
typeDef.getCppClassName());
|
||||
|
||||
// Emit the mnenomic, if specified.
|
||||
if (auto mnenomic = typeDef.getMnemonic()) {
|
||||
os << " static ::llvm::StringRef getMnemonic() { return \"" << mnenomic
|
||||
<< "\"; }\n";
|
||||
|
||||
// If mnemonic specified, emit print/parse declarations.
|
||||
os << typeDefParsePrint;
|
||||
}
|
||||
|
||||
if (typeDef.genAccessors()) {
|
||||
SmallVector<TypeParameter, 4> parameters;
|
||||
typeDef.getParameters(parameters);
|
||||
|
||||
for (TypeParameter ¶meter : parameters) {
|
||||
SmallString<16> name = parameter.getName();
|
||||
name[0] = llvm::toUpper(name[0]);
|
||||
os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name);
|
||||
}
|
||||
}
|
||||
|
||||
// End the typeDef decl.
|
||||
os << " };\n";
|
||||
}
|
||||
|
||||
/// Main entry point for decls.
|
||||
static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
emitSourceFileHeader("TypeDef Declarations", os);
|
||||
|
||||
SmallVector<TypeDef, 16> typeDefs;
|
||||
findAllTypeDefs(recordKeeper, typeDefs);
|
||||
|
||||
IfDefScope scope("GET_TYPEDEF_CLASSES", os);
|
||||
if (typeDefs.size() > 0) {
|
||||
NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
|
||||
|
||||
// Well known print/parse dispatch function declarations. These are called
|
||||
// from Dialect::parseType() and Dialect::printType() methods.
|
||||
os << " ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
|
||||
"::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n";
|
||||
os << " ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
|
||||
"::mlir::DialectAsmPrinter& printer);\n";
|
||||
os << "\n";
|
||||
|
||||
// Declare all the type classes first (in case they reference each other).
|
||||
for (const TypeDef &typeDef : typeDefs)
|
||||
os << " class " << typeDef.getCppClassName() << ";\n";
|
||||
|
||||
// Declare all the typedefs.
|
||||
for (const TypeDef &typeDef : typeDefs)
|
||||
emitTypeDefDecl(typeDef, os);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: TypeDef list
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
|
||||
raw_ostream &os) {
|
||||
IfDefScope scope("GET_TYPEDEF_LIST", os);
|
||||
for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) {
|
||||
os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName();
|
||||
if (i < typeDefs.end() - 1)
|
||||
os << ",\n";
|
||||
else
|
||||
os << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: TypeDef definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Beginning of storage class.
|
||||
/// {0}: Storage class namespace.
|
||||
/// {1}: Storage class c++ name.
|
||||
/// {2}: Parameters parameters.
|
||||
/// {3}: Parameter initialzer string.
|
||||
/// {4}: Parameter name list.
|
||||
/// {5}: Parameter types.
|
||||
static const char *const typeDefStorageClassBegin = R"(
|
||||
namespace {0} {{
|
||||
struct {1} : public ::mlir::TypeStorage {{
|
||||
{1} ({2})
|
||||
: {3} {{ }
|
||||
|
||||
/// The hash key for this storage is a pair of the integer and type params.
|
||||
using KeyTy = std::tuple<{5}>;
|
||||
|
||||
/// Define the comparison function for the key type.
|
||||
bool operator==(const KeyTy &key) const {{
|
||||
return key == KeyTy({4});
|
||||
}
|
||||
)";
|
||||
|
||||
/// The storage class' constructor template.
|
||||
/// {0}: storage class name.
|
||||
static const char *const typeDefStorageClassConstructorBegin = R"(
|
||||
/// Define a construction method for creating a new instance of this storage.
|
||||
static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{
|
||||
)";
|
||||
|
||||
/// The storage class' constructor return template.
|
||||
/// {0}: storage class name.
|
||||
/// {1}: list of parameters.
|
||||
static const char *const typeDefStorageClassConstructorReturn = R"(
|
||||
return new (allocator.allocate<{0}>())
|
||||
{0}({1});
|
||||
}
|
||||
)";
|
||||
|
||||
/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
|
||||
static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
|
||||
SmallVector<TypeParameter, 4> parameters;
|
||||
typeDef.getParameters(parameters);
|
||||
auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
|
||||
for (TypeParameter ¶meter : parameters) {
|
||||
auto allocCode = parameter.getAllocator();
|
||||
if (allocCode) {
|
||||
fmtCtxt.withSelf(parameter.getName());
|
||||
fmtCtxt.addSubst("_dst", parameter.getName());
|
||||
os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Emit the storage class code for type 'typeDef'.
|
||||
/// This includes (in-order):
|
||||
/// 1) typeDefStorageClassBegin, which includes:
|
||||
/// - The class constructor.
|
||||
/// - The KeyTy definition.
|
||||
/// - The equality (==) operator.
|
||||
/// 2) The hashKey method.
|
||||
/// 3) The construct method.
|
||||
/// 4) The list of parameters as the storage class member variables.
|
||||
static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
|
||||
SmallVector<TypeParameter, 4> parameters;
|
||||
typeDef.getParameters(parameters);
|
||||
|
||||
// Initialize a bunch of variables to be used later on.
|
||||
auto parameterNames = map_range(
|
||||
parameters, [](TypeParameter parameter) { return parameter.getName(); });
|
||||
auto parameterTypes = map_range(parameters, [](TypeParameter parameter) {
|
||||
return parameter.getCppType();
|
||||
});
|
||||
auto parameterList = join(parameterNames, ", ");
|
||||
auto parameterTypeList = join(parameterTypes, ", ");
|
||||
|
||||
// 1) Emit most of the storage class up until the hashKey body.
|
||||
os << formatv(
|
||||
typeDefStorageClassBegin, typeDef.getStorageNamespace(),
|
||||
typeDef.getStorageClassName(),
|
||||
TypeParamCommaFormatter(
|
||||
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
|
||||
TypeParamCommaFormatter(
|
||||
TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters),
|
||||
parameterList, parameterTypeList);
|
||||
|
||||
// 2) Emit the haskKey method.
|
||||
os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
|
||||
// Extract each parameter from the key.
|
||||
for (size_t i = 0, e = parameters.size(); i < e; ++i)
|
||||
os << formatv(" const auto &{0} = std::get<{1}>(key);\n",
|
||||
parameters[i].getName(), i);
|
||||
// Then combine them all. This requires all the parameters types to have a
|
||||
// hash_value defined.
|
||||
os << " return ::llvm::hash_combine(";
|
||||
interleaveComma(parameterNames, os);
|
||||
os << ");\n";
|
||||
os << " }\n";
|
||||
|
||||
// 3) Emit the construct method.
|
||||
if (typeDef.hasStorageCustomConstructor())
|
||||
// If user wants to build the storage constructor themselves, declare it
|
||||
// here and then they can write the definition elsewhere.
|
||||
os << " static " << typeDef.getStorageClassName()
|
||||
<< " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
|
||||
"&key);\n";
|
||||
else {
|
||||
// If not, autogenerate one.
|
||||
|
||||
// First, unbox the parameters.
|
||||
os << formatv(typeDefStorageClassConstructorBegin,
|
||||
typeDef.getStorageClassName());
|
||||
for (size_t i = 0; i < parameters.size(); ++i) {
|
||||
os << formatv(" auto {0} = std::get<{1}>(key);\n",
|
||||
parameters[i].getName(), i);
|
||||
}
|
||||
// Second, reassign the parameter variables with allocation code, if it's
|
||||
// specified.
|
||||
emitParameterAllocationCode(typeDef, os);
|
||||
|
||||
// Last, return an allocated copy.
|
||||
os << formatv(typeDefStorageClassConstructorReturn,
|
||||
typeDef.getStorageClassName(), parameterList);
|
||||
}
|
||||
|
||||
// 4) Emit the parameters as storage class members.
|
||||
for (auto parameter : parameters) {
|
||||
os << " " << parameter.getCppType() << " " << parameter.getName()
|
||||
<< ";\n";
|
||||
}
|
||||
os << " };\n";
|
||||
|
||||
os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
|
||||
}
|
||||
|
||||
/// Emit the parser and printer for a particular type, if they're specified.
|
||||
void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
|
||||
// Emit the printer code, if specified.
|
||||
if (auto printerCode = typeDef.getPrinterCode()) {
|
||||
// Both the mnenomic and printerCode must be defined (for parity with
|
||||
// parserCode).
|
||||
os << "void " << typeDef.getCppClassName()
|
||||
<< "::print(::mlir::DialectAsmPrinter& printer) const {\n";
|
||||
if (*printerCode == "") {
|
||||
// If no code specified, emit error.
|
||||
PrintFatalError(typeDef.getLoc(),
|
||||
typeDef.getName() +
|
||||
": printer (if specified) must have non-empty code");
|
||||
}
|
||||
auto fmtCtxt = FmtContext().addSubst("_printer", "printer");
|
||||
os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
|
||||
}
|
||||
|
||||
// emit a parser, if specified.
|
||||
if (auto parserCode = typeDef.getParserCode()) {
|
||||
// The mnenomic must be defined so the dispatcher knows how to dispatch.
|
||||
os << "::mlir::Type " << typeDef.getCppClassName()
|
||||
<< "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
|
||||
"parser) "
|
||||
"{\n";
|
||||
if (*parserCode == "") {
|
||||
// if no code specified, emit error.
|
||||
PrintFatalError(typeDef.getLoc(),
|
||||
typeDef.getName() +
|
||||
": parser (if specified) must have non-empty code");
|
||||
}
|
||||
auto fmtCtxt =
|
||||
FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt");
|
||||
os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
|
||||
}
|
||||
}
|
||||
|
||||
/// Print all the typedef-specific definition code.
|
||||
static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
|
||||
NamespaceEmitter ns(os, typeDef.getDialect());
|
||||
SmallVector<TypeParameter, 4> parameters;
|
||||
typeDef.getParameters(parameters);
|
||||
|
||||
// Emit the storage class, if requested and necessary.
|
||||
if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
|
||||
emitStorageClass(typeDef, os);
|
||||
|
||||
os << llvm::formatv(
|
||||
"{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
|
||||
" return Base::get(ctxt",
|
||||
typeDef.getCppClassName(),
|
||||
TypeParamCommaFormatter(
|
||||
TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma,
|
||||
parameters));
|
||||
for (TypeParameter ¶m : parameters)
|
||||
os << ", " << param.getName();
|
||||
os << ");\n}\n";
|
||||
|
||||
// Emit the parameter accessors.
|
||||
if (typeDef.genAccessors())
|
||||
for (const TypeParameter ¶meter : parameters) {
|
||||
SmallString<16> name = parameter.getName();
|
||||
name[0] = llvm::toUpper(name[0]);
|
||||
os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
|
||||
parameter.getCppType(), name, parameter.getName(),
|
||||
typeDef.getCppClassName());
|
||||
}
|
||||
|
||||
// If mnemonic is specified maybe print definitions for the parser and printer
|
||||
// code, if they're specified.
|
||||
if (typeDef.getMnemonic())
|
||||
emitParserPrinter(typeDef, os);
|
||||
}
|
||||
|
||||
/// Emit the dialect printer/parser dispatcher. User's code should call these
|
||||
/// functions from their dialect's print/parse methods.
|
||||
static void emitParsePrintDispatch(SmallVectorImpl<TypeDef> &typeDefs,
|
||||
raw_ostream &os) {
|
||||
if (typeDefs.size() == 0)
|
||||
return;
|
||||
const Dialect &dialect = typeDefs.begin()->getDialect();
|
||||
NamespaceEmitter ns(os, dialect);
|
||||
|
||||
// The parser dispatch is just a list of if-elses, matching on the mnemonic
|
||||
// and calling the class's parse function.
|
||||
os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
|
||||
"::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
|
||||
for (const TypeDef &typeDef : typeDefs)
|
||||
if (typeDef.getMnemonic())
|
||||
os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
|
||||
"{0}::{1}::parse(ctxt, parser);\n",
|
||||
typeDef.getDialect().getCppNamespace(),
|
||||
typeDef.getCppClassName());
|
||||
os << " return ::mlir::Type();\n";
|
||||
os << "}\n\n";
|
||||
|
||||
// The printer dispatch uses llvm::TypeSwitch to find and call the correct
|
||||
// printer.
|
||||
os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
|
||||
"::mlir::DialectAsmPrinter& printer) {\n"
|
||||
<< " ::mlir::LogicalResult found = ::mlir::success();\n"
|
||||
<< " ::llvm::TypeSwitch<::mlir::Type>(type)\n";
|
||||
for (auto typeDef : typeDefs)
|
||||
if (typeDef.getMnemonic())
|
||||
os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ "
|
||||
"t.dyn_cast<{0}::{1}>().print(printer); })\n",
|
||||
typeDef.getDialect().getCppNamespace(),
|
||||
typeDef.getCppClassName());
|
||||
os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); "
|
||||
"});\n"
|
||||
<< " return found;\n"
|
||||
<< "}\n\n";
|
||||
}
|
||||
|
||||
/// Entry point for typedef definitions.
|
||||
static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
emitSourceFileHeader("TypeDef Definitions", os);
|
||||
|
||||
SmallVector<TypeDef, 16> typeDefs;
|
||||
findAllTypeDefs(recordKeeper, typeDefs);
|
||||
emitTypeDefList(typeDefs, os);
|
||||
|
||||
IfDefScope scope("GET_TYPEDEF_CLASSES", os);
|
||||
emitParsePrintDispatch(typeDefs, os);
|
||||
for (auto typeDef : typeDefs)
|
||||
emitTypeDefDef(typeDef, os);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GEN: TypeDef registration hooks
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static mlir::GenRegistration
|
||||
genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
return emitTypeDefDefs(records, os);
|
||||
});
|
||||
|
||||
static mlir::GenRegistration
|
||||
genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations",
|
||||
[](const llvm::RecordKeeper &records, raw_ostream &os) {
|
||||
return emitTypeDefDecls(records, os);
|
||||
});
|
||||
Reference in New Issue
Block a user