[MLIR] Add support for defining Types in tblgen

Adds a TypeDef class to OpBase and backing generation code. Allows one
to define the Type, its parameters, and printer/parser methods in ODS.
Can generate the Type C++ class, accessors, storage class, per-parameter
custom allocators (for the storage constructor), and documentation.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D86904
This commit is contained in:
John Demme
2020-10-13 22:07:27 +00:00
parent ef3d17482f
commit 5fe53c4128
15 changed files with 1518 additions and 9 deletions

View File

@@ -9,6 +9,8 @@ function(add_mlir_dialect dialect dialect_namespace)
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls)
mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs)
mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace})
add_public_tablegen_target(MLIR${dialect}IncGen)
add_dependencies(mlir-headers MLIR${dialect}IncGen)

View File

@@ -2364,4 +2364,116 @@ def location;
// so to replace the matched DAG with an existing SSA value.
def replaceWithValue;
//===----------------------------------------------------------------------===//
// Data type generation
//===----------------------------------------------------------------------===//
// Define a new type belonging to a dialect and called 'name'.
class TypeDef<Dialect owningdialect, string name> {
Dialect dialect = owningdialect;
string cppClassName = name # "Type";
// Short summary of the type.
string summary = ?;
// The longer description of this type.
string description = ?;
// Name of storage class to generate or use.
string storageClass = name # "TypeStorage";
// Namespace (withing dialect c++ namespace) in which the storage class
// resides.
string storageNamespace = "detail";
// Specify if the storage class is to be generated.
bit genStorageClass = 1;
// Specify that the generated storage class has a constructor which is written
// in C++.
bit hasStorageCustomConstructor = 0;
// The list of parameters for this type. Parameters will become both
// parameters to the get() method and storage class member variables.
//
// The format of this dag is:
// (ins
// "<c++ type>":$param1Name,
// "<c++ type>":$param2Name,
// TypeParameter<"c++ type", "param description">:$param3Name)
// TypeParameters (or more likely one of their subclasses) are required to add
// more information about the parameter, specifically:
// - Documentation
// - Code to allocate the parameter (if allocation is needed in the storage
// class constructor)
//
// For example:
// (ins
// "int":$width,
// ArrayRefParameter<"bool", "list of bools">:$yesNoArray)
//
// (ArrayRefParameter is a subclass of TypeParameter which has allocation code
// for re-allocating ArrayRefs. It is defined below.)
dag parameters = (ins);
// Use the lowercased name as the keyword for parsing/printing. Specify only
// if you want tblgen to generate declarations and/or definitions of
// printer/parser for this type.
string mnemonic = ?;
// If 'mnemonic' specified,
// If null, generate just the declarations.
// If a non-empty code block, just use that code as the definition code.
// Error if an empty code block.
code printer = ?;
code parser = ?;
// If set, generate accessors for each Type parameter.
bit genAccessors = 1;
// Generate the verifyConstructionInvariants declaration and getChecked
// method.
bit genVerifyInvariantsDecl = 0;
// Extra code to include in the class declaration.
code extraClassDeclaration = [{}];
}
// 'Parameters' should be subclasses of this or simple strings (which is a
// shorthand for TypeParameter<"C++Type">).
class TypeParameter<string type, string desc> {
// Custom memory allocation code for storage constructor.
code allocator = ?;
// The C++ type of this parameter.
string cppType = type;
// A description of this parameter.
string description = desc;
// The format string for the asm syntax (documentation only).
string syntax = ?;
}
// For StringRefs, which require allocation.
class StringRefParameter<string desc> :
TypeParameter<"::llvm::StringRef", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
}
// For standard ArrayRefs, which require allocation.
class ArrayRefParameter<string arrayOf, string desc> :
TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
let allocator = [{$_dst = $_allocator.copyInto($_self);}];
}
// For classes which require allocation and have their own allocateInto method.
class SelfAllocationParameter<string type, string desc> :
TypeParameter<type, desc> {
let allocator = [{$_dst = $_self.allocateInto($_allocator);}];
}
// For ArrayRefs which contain things which allocate themselves.
class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> {
let allocator = [{
llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields;
for (size_t i = 0, e = $_self.size(); i < e; ++i)
tmpFields.push_back($_self[i].allocateInto($_allocator));
$_dst = $_allocator.copyInto(ArrayRef<}] # arrayOf # [{>(tmpFields));
}];
}
#endif // OP_BASE

View File

@@ -0,0 +1,135 @@
//===-- TypeDef.h - Record wrapper for type definitions ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TypeDef wrapper to simplify using TableGen Record defining a MLIR type.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TABLEGEN_TYPEDEF_H
#define MLIR_TABLEGEN_TYPEDEF_H
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Dialect.h"
namespace llvm {
class Record;
class DagInit;
class SMLoc;
} // namespace llvm
namespace mlir {
namespace tblgen {
class TypeParameter;
/// Wrapper class that contains a TableGen TypeDef's record and provides helper
/// methods for accessing them.
class TypeDef {
public:
explicit TypeDef(const llvm::Record *def) : def(def) {}
// Get the dialect for which this type belongs.
Dialect getDialect() const;
// Returns the name of this TypeDef record.
StringRef getName() const;
// Query functions for the documentation of the operator.
bool hasDescription() const;
StringRef getDescription() const;
bool hasSummary() const;
StringRef getSummary() const;
// Returns the name of the C++ class to generate.
StringRef getCppClassName() const;
// Returns the name of the storage class for this type.
StringRef getStorageClassName() const;
// Returns the C++ namespace for this types storage class.
StringRef getStorageNamespace() const;
// Returns true if we should generate the storage class.
bool genStorageClass() const;
// Indicates whether or not to generate the storage class constructor.
bool hasStorageCustomConstructor() const;
// Fill a list with this types parameters. See TypeDef in OpBase.td for
// documentation of parameter usage.
void getParameters(SmallVectorImpl<TypeParameter> &) const;
// Return the number of type parameters
unsigned getNumParameters() const;
// Return the keyword/mnemonic to use in the printer/parser methods if we are
// supposed to auto-generate them.
Optional<StringRef> getMnemonic() const;
// Returns the code to use as the types printer method. If not specified,
// return a non-value. Otherwise, return the contents of that code block.
Optional<StringRef> getPrinterCode() const;
// Returns the code to use as the types parser method. If not specified,
// return a non-value. Otherwise, return the contents of that code block.
Optional<StringRef> getParserCode() const;
// Returns true if the accessors based on the types parameters should be
// generated.
bool genAccessors() const;
// Return true if we need to generate the verifyConstructionInvariants
// declaration and getChecked method.
bool genVerifyInvariantsDecl() const;
// Returns the dialects extra class declaration code.
Optional<StringRef> getExtraDecls() const;
// Get the code location (for error printing).
ArrayRef<llvm::SMLoc> getLoc() const;
// Returns whether two TypeDefs are equal by checking the equality of the
// underlying record.
bool operator==(const TypeDef &other) const;
// Compares two TypeDefs by comparing the names of the dialects.
bool operator<(const TypeDef &other) const;
// Returns whether the TypeDef is defined.
operator bool() const { return def != nullptr; }
private:
const llvm::Record *def;
};
// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs
// to parameterize them.
class TypeParameter {
public:
explicit TypeParameter(const llvm::DagInit *def, unsigned num)
: def(def), num(num) {}
// Get the parameter name.
StringRef getName() const;
// If specified, get the custom allocator code for this parameter.
llvm::Optional<StringRef> getAllocator() const;
// Get the C++ type of this parameter.
StringRef getCppType() const;
// Get a description of this parameter for documentation purposes.
llvm::Optional<StringRef> getDescription() const;
// Get the assembly syntax documentation.
StringRef getSyntax() const;
private:
const llvm::DagInit *def;
const unsigned num;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_TYPEDEF_H

View File

@@ -25,6 +25,7 @@ llvm_add_library(MLIRTableGen STATIC
SideEffects.cpp
Successor.cpp
Type.cpp
TypeDef.cpp
DISABLE_LLVM_LINK_LLVM_DYLIB

View File

@@ -0,0 +1,160 @@
//===- TypeDef.cpp - TypeDef wrapper class --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TypeDef wrapper to simplify using TableGen Record defining a MLIR dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/TypeDef.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
Dialect TypeDef::getDialect() const {
auto *dialectDef =
dyn_cast<llvm::DefInit>(def->getValue("dialect")->getValue());
if (dialectDef == nullptr)
return Dialect(nullptr);
return Dialect(dialectDef->getDef());
}
StringRef TypeDef::getName() const { return def->getName(); }
StringRef TypeDef::getCppClassName() const {
return def->getValueAsString("cppClassName");
}
bool TypeDef::hasDescription() const {
const llvm::RecordVal *s = def->getValue("description");
return s != nullptr && isa<llvm::StringInit>(s->getValue());
}
StringRef TypeDef::getDescription() const {
return def->getValueAsString("description");
}
bool TypeDef::hasSummary() const {
const llvm::RecordVal *s = def->getValue("summary");
return s != nullptr && isa<llvm::StringInit>(s->getValue());
}
StringRef TypeDef::getSummary() const {
return def->getValueAsString("summary");
}
StringRef TypeDef::getStorageClassName() const {
return def->getValueAsString("storageClass");
}
StringRef TypeDef::getStorageNamespace() const {
return def->getValueAsString("storageNamespace");
}
bool TypeDef::genStorageClass() const {
return def->getValueAsBit("genStorageClass");
}
bool TypeDef::hasStorageCustomConstructor() const {
return def->getValueAsBit("hasStorageCustomConstructor");
}
void TypeDef::getParameters(SmallVectorImpl<TypeParameter> &parameters) const {
auto *parametersDag = def->getValueAsDag("parameters");
if (parametersDag != nullptr) {
size_t numParams = parametersDag->getNumArgs();
for (unsigned i = 0; i < numParams; i++)
parameters.push_back(TypeParameter(parametersDag, i));
}
}
unsigned TypeDef::getNumParameters() const {
auto *parametersDag = def->getValueAsDag("parameters");
return parametersDag ? parametersDag->getNumArgs() : 0;
}
llvm::Optional<StringRef> TypeDef::getMnemonic() const {
return def->getValueAsOptionalString("mnemonic");
}
llvm::Optional<StringRef> TypeDef::getPrinterCode() const {
return def->getValueAsOptionalCode("printer");
}
llvm::Optional<StringRef> TypeDef::getParserCode() const {
return def->getValueAsOptionalCode("parser");
}
bool TypeDef::genAccessors() const {
return def->getValueAsBit("genAccessors");
}
bool TypeDef::genVerifyInvariantsDecl() const {
return def->getValueAsBit("genVerifyInvariantsDecl");
}
llvm::Optional<StringRef> TypeDef::getExtraDecls() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
llvm::ArrayRef<llvm::SMLoc> TypeDef::getLoc() const { return def->getLoc(); }
bool TypeDef::operator==(const TypeDef &other) const {
return def == other.def;
}
bool TypeDef::operator<(const TypeDef &other) const {
return getName() < other.getName();
}
StringRef TypeParameter::getName() const {
return def->getArgName(num)->getValue();
}
llvm::Optional<StringRef> TypeParameter::getAllocator() const {
llvm::Init *parameterType = def->getArg(num);
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return llvm::Optional<StringRef>();
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
llvm::RecordVal *code = typeParameter->getDef()->getValue("allocator");
if (llvm::CodeInit *ci = dyn_cast<llvm::CodeInit>(code->getValue()))
return ci->getValue();
if (isa<llvm::UnsetInit>(code->getValue()))
return llvm::Optional<StringRef>();
llvm::PrintFatalError(
typeParameter->getDef()->getLoc(),
"Record `" + def->getArgName(num)->getValue() +
"', field `printer' does not have a code initializer!");
}
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from TypeParameter\n");
}
StringRef TypeParameter::getCppType() const {
auto *parameterType = def->getArg(num);
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return stringType->getValue();
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType))
return typeParameter->getDef()->getValueAsString("cppType");
llvm::PrintFatalError(
"Parameters DAG arguments must be either strings or defs "
"which inherit from TypeParameter\n");
}
llvm::Optional<StringRef> TypeParameter::getDescription() const {
auto *parameterType = def->getArg(num);
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
const auto *desc = typeParameter->getDef()->getValue("description");
if (llvm::StringInit *ci = dyn_cast<llvm::StringInit>(desc->getValue()))
return ci->getValue();
}
return llvm::Optional<StringRef>();
}
StringRef TypeParameter::getSyntax() const {
auto *parameterType = def->getArg(num);
if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
return stringType->getValue();
if (auto *typeParameter = dyn_cast<llvm::DefInit>(parameterType)) {
const auto *syntax = typeParameter->getDef()->getValue("syntax");
if (syntax && isa<llvm::StringInit>(syntax->getValue()))
return dyn_cast<llvm::StringInit>(syntax->getValue())->getValue();
return getCppType();
}
llvm::PrintFatalError("Parameters DAG arguments must be either strings or "
"defs which inherit from TypeParameter");
}

View File

@@ -9,6 +9,12 @@ mlir_tablegen(TestTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRTestInterfaceIncGen)
set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td)
mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls)
mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRTestDefIncGen)
set(LLVM_TARGET_DEFINITIONS TestOps.td)
mlir_tablegen(TestOps.h.inc -gen-op-decls)
mlir_tablegen(TestOps.cpp.inc -gen-op-defs)
@@ -25,11 +31,13 @@ add_mlir_library(MLIRTestDialect
TestDialect.cpp
TestPatterns.cpp
TestTraits.cpp
TestTypes.cpp
EXCLUDE_FROM_LIBMLIR
DEPENDS
MLIRTestInterfaceIncGen
MLIRTestDefIncGen
MLIRTestOpsIncGen
LINK_LIBS PUBLIC

View File

@@ -141,16 +141,23 @@ void TestDialect::initialize() {
>();
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface>();
addTypes<TestType, TestRecursiveType>();
addTypes<TestType, TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
>();
allowUnknownOperations();
}
static Type parseTestType(DialectAsmParser &parser,
static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser,
llvm::SetVector<Type> &stack) {
StringRef typeTag;
if (failed(parser.parseKeyword(&typeTag)))
return Type();
auto genType = generatedTypeParser(ctxt, parser, typeTag);
if (genType != Type())
return genType;
if (typeTag == "test_type")
return TestType::get(parser.getBuilder().getContext());
@@ -174,7 +181,7 @@ static Type parseTestType(DialectAsmParser &parser,
if (failed(parser.parseComma()))
return Type();
stack.insert(rec);
Type subtype = parseTestType(parser, stack);
Type subtype = parseTestType(ctxt, parser, stack);
stack.pop_back();
if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype)))
return Type();
@@ -184,11 +191,13 @@ static Type parseTestType(DialectAsmParser &parser,
Type TestDialect::parseType(DialectAsmParser &parser) const {
llvm::SetVector<Type> stack;
return parseTestType(parser, stack);
return parseTestType(getContext(), parser, stack);
}
static void printTestType(Type type, DialectAsmPrinter &printer,
llvm::SetVector<Type> &stack) {
if (succeeded(generatedTypePrinter(type, printer)))
return;
if (type.isa<TestType>()) {
printer << "test_type";
return;

View File

@@ -0,0 +1,150 @@
//===-- TestTypeDefs.td - Test dialect type definitions ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TableGen data type definitions for Test dialect.
//
//===----------------------------------------------------------------------===//
#ifndef TEST_TYPEDEFS
#define TEST_TYPEDEFS
// To get the test dialect def.
include "TestOps.td"
// All of the types will extend this class.
class Test_Type<string name> : TypeDef<Test_Dialect, name> { }
def SimpleTypeA : Test_Type<"SimpleA"> {
let mnemonic = "smpla";
let printer = [{ $_printer << "smpla"; }];
let parser = [{ return get($_ctxt); }];
}
// A more complex parameterized type.
def CompoundTypeA : Test_Type<"CompoundA"> {
let mnemonic = "cmpnd_a";
// List of type parameters.
let parameters = (
ins
"int":$widthOfSomething,
"::mlir::Type":$oneType,
// This is special syntax since ArrayRefs require allocation in the
// constructor.
ArrayRefParameter<
"int", // The parameter C++ type.
"An example of an array of ints" // Parameter description.
>: $arrayOfInts
);
let extraClassDeclaration = [{
struct SomeCppStruct {};
}];
}
// An example of how one could implement a standard integer.
def IntegerType : Test_Type<"TestInteger"> {
let mnemonic = "int";
let genVerifyInvariantsDecl = 1;
let parameters = (
ins
// SignednessSemantics is defined below.
"::mlir::TestIntegerType::SignednessSemantics":$signedness,
"unsigned":$width
);
// We define the printer inline.
let printer = [{
$_printer << "int<";
printSignedness($_printer, getImpl()->signedness);
$_printer << ", " << getImpl()->width << ">";
}];
// The parser is defined here also.
let parser = [{
if (parser.parseLess()) return Type();
SignednessSemantics signedness;
if (parseSignedness($_parser, signedness)) return mlir::Type();
if ($_parser.parseComma()) return Type();
int width;
if ($_parser.parseInteger(width)) return Type();
if ($_parser.parseGreater()) return Type();
return get(ctxt, signedness, width);
}];
// Any extra code one wants in the type's class declaration.
let extraClassDeclaration = [{
/// Signedness semantics.
enum SignednessSemantics {
Signless, /// No signedness semantics
Signed, /// Signed integer
Unsigned, /// Unsigned integer
};
/// This extra function is necessary since it doesn't include signedness
static IntegerType getChecked(unsigned width, Location location);
/// Return true if this is a signless integer type.
bool isSignless() const { return getSignedness() == Signless; }
/// Return true if this is a signed integer type.
bool isSigned() const { return getSignedness() == Signed; }
/// Return true if this is an unsigned integer type.
bool isUnsigned() const { return getSignedness() == Unsigned; }
}];
}
// A parent type for any type which is just a list of fields (e.g. structs,
// unions).
class FieldInfo_Type<string name> : Test_Type<name> {
let parameters = (
ins
// An ArrayRef of something which requires allocation in the storage
// constructor.
ArrayRefOfSelfAllocationParameter<
"::mlir::FieldInfo", // FieldInfo is defined/declared in TestTypes.h.
"Models struct fields">: $fields
);
// Prints the type in this format:
// struct<[{field1Name, field1Type}, {field2Name, field2Type}]
let printer = [{
$_printer << "struct" << "<";
for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) {
const auto& field = getImpl()->fields[i];
$_printer << "{" << field.name << "," << field.type << "}";
if (i < getImpl()->fields.size() - 1)
$_printer << ",";
}
$_printer << ">";
}];
// Parses the above format
let parser = [{
llvm::SmallVector<FieldInfo, 4> parameters;
if ($_parser.parseLess()) return Type();
while (mlir::succeeded($_parser.parseOptionalLBrace())) {
StringRef name;
if ($_parser.parseKeyword(&name)) return Type();
if ($_parser.parseComma()) return Type();
Type type;
if ($_parser.parseType(type)) return Type();
if ($_parser.parseRBrace()) return Type();
parameters.push_back(FieldInfo {name, type});
if ($_parser.parseOptionalComma()) break;
}
if ($_parser.parseGreater()) return Type();
return get($_ctxt, parameters);
}];
}
def StructType : FieldInfo_Type<"Struct"> {
let mnemonic = "struct";
}
#endif // TEST_TYPEDEFS

View File

@@ -0,0 +1,117 @@
//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains types defined by the TestDialect for testing various
// features of MLIR.
//
//===----------------------------------------------------------------------===//
#include "TestTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
// Custom parser for SignednessSemantics.
static ParseResult
parseSignedness(DialectAsmParser &parser,
TestIntegerType::SignednessSemantics &result) {
StringRef signStr;
auto loc = parser.getCurrentLocation();
if (parser.parseKeyword(&signStr))
return failure();
if (signStr.compare_lower("u") || signStr.compare_lower("unsigned"))
result = TestIntegerType::SignednessSemantics::Unsigned;
else if (signStr.compare_lower("s") || signStr.compare_lower("signed"))
result = TestIntegerType::SignednessSemantics::Signed;
else if (signStr.compare_lower("n") || signStr.compare_lower("none"))
result = TestIntegerType::SignednessSemantics::Signless;
else
return parser.emitError(loc, "expected signed, unsigned, or none");
return success();
}
// Custom printer for SignednessSemantics.
static void printSignedness(DialectAsmPrinter &printer,
const TestIntegerType::SignednessSemantics &ss) {
switch (ss) {
case TestIntegerType::SignednessSemantics::Unsigned:
printer << "unsigned";
break;
case TestIntegerType::SignednessSemantics::Signed:
printer << "signed";
break;
case TestIntegerType::SignednessSemantics::Signless:
printer << "none";
break;
}
}
Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) {
int widthOfSomething;
Type oneType;
SmallVector<int, 4> arrayOfInts;
if (parser.parseLess() || parser.parseInteger(widthOfSomething) ||
parser.parseComma() || parser.parseType(oneType) || parser.parseComma() ||
parser.parseLSquare())
return Type();
int i;
while (!*parser.parseOptionalInteger(i)) {
arrayOfInts.push_back(i);
if (parser.parseOptionalComma())
break;
}
if (parser.parseRSquare() || parser.parseGreater())
return Type();
return get(ctxt, widthOfSomething, oneType, arrayOfInts);
}
void CompoundAType::print(DialectAsmPrinter &printer) const {
printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType()
<< ", [";
auto intArray = getArrayOfInts();
llvm::interleaveComma(intArray, printer);
printer << "]>";
}
// The functions don't need to be in the header file, but need to be in the mlir
// namespace. Declare them here, then define them immediately below. Separating
// the declaration and definition adheres to the LLVM coding standards.
namespace mlir {
// FieldInfo is used as part of a parameter, so equality comparison is
// compulsory.
static bool operator==(const FieldInfo &a, const FieldInfo &b);
// FieldInfo is used as part of a parameter, so a hash will be computed.
static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
} // namespace mlir
// FieldInfo is used as part of a parameter, so equality comparison is
// compulsory.
static bool mlir::operator==(const FieldInfo &a, const FieldInfo &b) {
return a.name == b.name && a.type == b.type;
}
// FieldInfo is used as part of a parameter, so a hash will be computed.
static llvm::hash_code mlir::hash_value(const FieldInfo &fi) { // NOLINT
return llvm::hash_combine(fi.name, fi.type);
}
// Example type validity checker.
LogicalResult TestIntegerType::verifyConstructionInvariants(
Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) {
if (width > 8)
return failure();
return success();
}
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.cpp.inc"

View File

@@ -14,11 +14,35 @@
#ifndef MLIR_TESTTYPES_H
#define MLIR_TESTTYPES_H
#include <tuple>
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
namespace mlir {
/// FieldInfo represents a field in the StructType data type. It is used as a
/// parameter in TestTypeDefs.td.
struct FieldInfo {
StringRef name;
Type type;
// Custom allocation called from generated constructor code
FieldInfo allocateInto(TypeStorageAllocator &alloc) const {
return FieldInfo{alloc.copyInto(name), type};
}
};
} // namespace mlir
#define GET_TYPEDEF_CLASSES
#include "TestTypeDefs.h.inc"
namespace mlir {
#include "TestTypeInterfaces.h.inc"
/// This class is a simple test type that uses a generated interface.

View File

@@ -0,0 +1,24 @@
// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
//////////////
// Tests the types in the 'Test' dialect, not the ones in 'typedefs.mlir'
// CHECK: @simpleA(%arg0: !test.smpla)
func @simpleA(%A : !test.smpla) -> () {
return
}
// CHECK: @compoundA(%arg0: !test.cmpnd_a<1, !test.smpla, [5, 6]>)
func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
return
}
// CHECK: @testInt(%arg0: !test.int<unsigned, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<unsigned, 1>)
func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
return
}
// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<unsigned, 3>}>)
func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int<none, 3>} > ) {
return
}

View File

@@ -0,0 +1,132 @@
// RUN: mlir-tblgen -gen-typedef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
include "mlir/IR/OpBase.td"
// DECL: #ifdef GET_TYPEDEF_CLASSES
// DECL: #undef GET_TYPEDEF_CLASSES
// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);
// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer);
// DEF: #ifdef GET_TYPEDEF_LIST
// DEF: #undef GET_TYPEDEF_LIST
// DEF: ::mlir::test::SimpleAType,
// DEF: ::mlir::test::CompoundAType,
// DEF: ::mlir::test::IndexType,
// DEF: ::mlir::test::SingleParameterType,
// DEF: ::mlir::test::IntegerType
// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic)
// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser);
// DEF return ::mlir::Type();
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect
// DEF-NOT: TestDialect
let name = "TestDialect";
let cppNamespace = "::mlir::test";
}
class TestType<string name> : TypeDef<Test_Dialect, name> { }
def A_SimpleTypeA : TestType<"SimpleA"> {
// DECL: class SimpleAType: public ::mlir::Type
}
// A more complex parameterized type
def B_CompoundTypeA : TestType<"CompoundA"> {
let summary = "A more complex parameterized type";
let description = "This type is to test a reasonably complex type";
let mnemonic = "cmpnd_a";
let parameters = (
ins
"int":$widthOfSomething,
"::mlir::test::SimpleTypeA": $exampleTdType,
"SomeCppStruct": $exampleCppType,
ArrayRefParameter<"int", "Matrix dimensions">:$dims
);
let genVerifyInvariantsDecl = 1;
// DECL-LABEL: class CompoundAType: public ::mlir::Type
// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
// DECL: int getWidthOfSomething() const;
// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
// DECL: SomeCppStruct getExampleCppType() const;
}
def C_IndexType : TestType<"Index"> {
let mnemonic = "index";
let parameters = (
ins
StringRefParameter<"Label for index">:$label
);
// DECL-LABEL: class IndexType: public ::mlir::Type
// DECL: static ::llvm::StringRef getMnemonic() { return "index"; }
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
}
def D_SingleParameterType : TestType<"SingleParameter"> {
let parameters = (
ins
"int": $num
);
// DECL-LABEL: struct SingleParameterTypeStorage;
// DECL-LABEL: class SingleParameterType
// DECL-NEXT: detail::SingleParameterTypeStorage
}
def E_IntegerType : TestType<"Integer"> {
let mnemonic = "int";
let genVerifyInvariantsDecl = 1;
let parameters = (
ins
"SignednessSemantics":$signedness,
TypeParameter<"unsigned", "Bitwdith of integer">:$width
);
// DECL-LABEL: IntegerType: public ::mlir::Type
let extraClassDeclaration = [{
/// Signedness semantics.
enum SignednessSemantics {
Signless, /// No signedness semantics
Signed, /// Signed integer
Unsigned, /// Unsigned integer
};
/// This extra function is necessary since it doesn't include signedness
static IntegerType getChecked(unsigned width, Location location);
/// Return true if this is a signless integer type.
bool isSignless() const { return getSignedness() == Signless; }
/// Return true if this is a signed integer type.
bool isSigned() const { return getSignedness() == Signed; }
/// Return true if this is an unsigned integer type.
bool isUnsigned() const { return getSignedness() == Unsigned; }
}];
// DECL: /// Signedness semantics.
// DECL-NEXT: enum SignednessSemantics {
// DECL-NEXT: Signless, /// No signedness semantics
// DECL-NEXT: Signed, /// Signed integer
// DECL-NEXT: Unsigned, /// Unsigned integer
// DECL-NEXT: };
// DECL: /// This extra function is necessary since it doesn't include signedness
// DECL-NEXT: static IntegerType getChecked(unsigned width, Location location);
// DECL: /// Return true if this is a signless integer type.
// DECL-NEXT: bool isSignless() const { return getSignedness() == Signless; }
// DECL-NEXT: /// Return true if this is a signed integer type.
// DECL-NEXT: bool isSigned() const { return getSignedness() == Signed; }
// DECL-NEXT: /// Return true if this is an unsigned integer type.
// DECL-NEXT: bool isUnsigned() const { return getSignedness() == Unsigned; }
}

View File

@@ -20,6 +20,7 @@ add_tablegen(mlir-tblgen MLIR
RewriterGen.cpp
SPIRVUtilsGen.cpp
StructsGen.cpp
TypeDefGen.cpp
)
set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")

View File

@@ -15,6 +15,7 @@
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/TypeDef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -23,6 +24,8 @@
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include <set>
using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
@@ -155,12 +158,67 @@ static void emitTypeDoc(const Type &type, raw_ostream &os) {
os << "\n";
}
//===----------------------------------------------------------------------===//
// TypeDef Documentation
//===----------------------------------------------------------------------===//
/// Emit the assembly format of a type.
static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters;
td.getParameters(parameters);
if (parameters.size() == 0) {
os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic()
<< "`\n";
return;
}
os << "\nSyntax:\n\n```\n!" << td.getDialect().getName() << "."
<< td.getMnemonic() << "<\n";
for (auto *it = parameters.begin(), *e = parameters.end(); it < e; ++it) {
os << " " << it->getSyntax();
if (it < parameters.end() - 1)
os << ",";
os << " # " << it->getName() << "\n";
}
os << ">\n```\n";
}
static void emitTypeDefDoc(TypeDef td, raw_ostream &os) {
os << llvm::formatv("### `{0}` ({1})\n", td.getName(), td.getCppClassName());
// Emit the summary, syntax, and description if present.
if (td.hasSummary())
os << "\n" << td.getSummary() << "\n";
if (td.getMnemonic() && td.getPrinterCode() && *td.getPrinterCode() == "" &&
td.getParserCode() && *td.getParserCode() == "")
emitTypeAssemblyFormat(td, os);
if (td.hasDescription())
mlir::tblgen::emitDescription(td.getDescription(), os);
// Emit attribute documentation.
SmallVector<TypeParameter, 4> parameters;
td.getParameters(parameters);
if (parameters.size() != 0) {
os << "\n#### Type parameters:\n\n";
os << "| Parameter | C++ type | Description |\n"
<< "| :-------: | :-------: | ----------- |\n";
for (const auto &it : parameters) {
auto desc = it.getDescription();
os << "| " << it.getName() << " | `" << td.getCppClassName() << "` | "
<< (desc ? *desc : "") << " |\n";
}
}
os << "\n";
}
//===----------------------------------------------------------------------===//
// Dialect Documentation
//===----------------------------------------------------------------------===//
static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
ArrayRef<Type> types, raw_ostream &os) {
ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
raw_ostream &os) {
os << "# '" << dialect.getName() << "' Dialect\n\n";
emitIfNotEmpty(dialect.getSummary(), os);
emitIfNotEmpty(dialect.getDescription(), os);
@@ -169,7 +227,7 @@ static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
// TODO: Add link between use and def for types
if (!types.empty()) {
os << "## Type definition\n\n";
os << "## Type constraint definition\n\n";
for (const Type &type : types)
emitTypeDoc(type, os);
}
@@ -179,28 +237,43 @@ static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
for (const Operator &op : ops)
emitOpDoc(op, os);
}
if (!typeDefs.empty()) {
os << "## Type definition\n\n";
for (const TypeDef &td : typeDefs)
emitTypeDefDoc(td, os);
}
}
static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
const auto &opDefs = recordKeeper.getAllDerivedDefinitions("Op");
const auto &typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
const auto &typeDefDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
std::set<Dialect> dialectsWithDocs;
std::map<Dialect, std::vector<Operator>> dialectOps;
std::map<Dialect, std::vector<Type>> dialectTypes;
std::map<Dialect, std::vector<TypeDef>> dialectTypeDefs;
for (auto *opDef : opDefs) {
Operator op(opDef);
dialectOps[op.getDialect()].push_back(op);
dialectsWithDocs.insert(op.getDialect());
}
for (auto *typeDef : typeDefs) {
Type type(typeDef);
if (auto dialect = type.getDialect())
dialectTypes[dialect].push_back(type);
}
for (auto *typeDef : typeDefDefs) {
TypeDef type(typeDef);
dialectTypeDefs[type.getDialect()].push_back(type);
dialectsWithDocs.insert(type.getDialect());
}
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
for (const auto &dialectWithOps : dialectOps)
emitDialectDoc(dialectWithOps.first, dialectWithOps.second,
dialectTypes[dialectWithOps.first], os);
for (auto dialect : dialectsWithDocs)
emitDialectDoc(dialect, dialectOps[dialect], dialectTypes[dialect],
dialectTypeDefs[dialect], os);
}
//===----------------------------------------------------------------------===//

View File

@@ -0,0 +1,561 @@
//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TypeDefGen uses the description of typeDefs to generate C++ definitions.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/TypeDef.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
#define DEBUG_TYPE "mlir-tblgen-typedefgen"
using namespace mlir;
using namespace mlir::tblgen;
static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
static llvm::cl::opt<std::string>
selectedDialect("typedefs-dialect",
llvm::cl::desc("Gen types for this dialect"),
llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
/// Find all the TypeDefs for the specified dialect. If no dialect specified and
/// can only find one dialect's types, use that.
static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
SmallVectorImpl<TypeDef> &typeDefs) {
auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
auto defs = llvm::map_range(
recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
if (defs.empty())
return;
StringRef dialectName;
if (selectedDialect.getNumOccurrences() == 0) {
if (defs.empty())
return;
llvm::SmallSet<Dialect, 4> dialects;
for (const TypeDef &typeDef : defs)
dialects.insert(typeDef.getDialect());
if (dialects.size() != 1)
llvm::PrintFatalError("TypeDefs belonging to more than one dialect. Must "
"select one via '--typedefs-dialect'");
dialectName = (*dialects.begin()).getName();
} else if (selectedDialect.getNumOccurrences() == 1) {
dialectName = selectedDialect.getValue();
} else {
llvm::PrintFatalError("Cannot select multiple dialects for which to "
"generate types via '--typedefs-dialect'.");
}
for (const TypeDef &typeDef : defs)
if (typeDef.getDialect().getName().equals(dialectName))
typeDefs.push_back(typeDef);
}
namespace {
/// Pass an instance of this class to llvm::formatv() to emit a comma separated
/// list of parameters in the format by 'EmitFormat'.
class TypeParamCommaFormatter : public llvm::detail::format_adapter {
public:
/// Choose the output format
enum EmitFormat {
/// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name,
/// [...]".
TypeNamePairs,
/// Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name,
/// [...]".
TypeNamePairsPrependComma,
/// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
TypeNameInitializer
};
TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params)
: emitFormat(emitFormat), params(params) {}
/// llvm::formatv will call this function when using an instance as a
/// replacement value.
void format(raw_ostream &os, StringRef options) {
if (params.size() && emitFormat == EmitFormat::TypeNamePairsPrependComma)
os << ", ";
switch (emitFormat) {
case EmitFormat::TypeNamePairs:
case EmitFormat::TypeNamePairsPrependComma:
interleaveComma(params, os,
[&](const TypeParameter &p) { emitTypeNamePair(p, os); });
break;
case EmitFormat::TypeNameInitializer:
interleaveComma(params, os, [&](const TypeParameter &p) {
emitTypeNameInitializer(p, os);
});
break;
}
}
private:
// Emit "paramType paramName".
static void emitTypeNamePair(const TypeParameter &param, raw_ostream &os) {
os << param.getCppType() << " " << param.getName();
}
// Emit "paramName(paramName)"
void emitTypeNameInitializer(const TypeParameter &param, raw_ostream &os) {
os << param.getName() << "(" << param.getName() << ")";
}
EmitFormat emitFormat;
ArrayRef<TypeParameter> params;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// GEN: TypeDef declarations
//===----------------------------------------------------------------------===//
/// The code block for the start of a typeDef class declaration -- singleton
/// case.
///
/// {0}: The name of the typeDef class.
static const char *const typeDefDeclSingletonBeginStr = R"(
class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
)";
/// The code block for the start of a typeDef class declaration -- parametric
/// case.
///
/// {0}: The name of the typeDef class.
/// {1}: The typeDef storage class namespace.
/// {2}: The storage class name.
/// {3}: The list of parameters with types.
static const char *const typeDefDeclParametricBeginStr = R"(
namespace {1} {
struct {2};
}
class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type,
{1}::{2}> {{
public:
/// Inherit some necessary constructors from 'TypeBase'.
using Base::Base;
)";
/// The snippet for print/parse.
static const char *const typeDefParsePrint = R"(
static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
void print(::mlir::DialectAsmPrinter& printer) const;
)";
/// The code block for the verifyConstructionInvariants and getChecked.
///
/// {0}: List of parameters, parameters style.
/// {1}: C++ type class name.
static const char *const typeDefDeclVerifyStr = R"(
static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
static {1} getChecked(Location loc{0});
)";
/// Generate the declaration for the given typeDef class.
static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> params;
typeDef.getParameters(params);
// Emit the beginning string template: either the singleton or parametric
// template.
if (typeDef.getNumParameters() == 0)
os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
typeDef.getStorageNamespace(), typeDef.getStorageClassName());
else
os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
typeDef.getStorageNamespace(), typeDef.getStorageClassName());
// Emit the extra declarations first in case there's a type definition in
// there.
if (Optional<StringRef> extraDecl = typeDef.getExtraDecls())
os << *extraDecl << "\n";
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, params);
os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
// Emit the verify invariants declaration.
if (typeDef.genVerifyInvariantsDecl())
os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma,
typeDef.getCppClassName());
// Emit the mnenomic, if specified.
if (auto mnenomic = typeDef.getMnemonic()) {
os << " static ::llvm::StringRef getMnemonic() { return \"" << mnenomic
<< "\"; }\n";
// If mnemonic specified, emit print/parse declarations.
os << typeDefParsePrint;
}
if (typeDef.genAccessors()) {
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
for (TypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name);
}
}
// End the typeDef decl.
os << " };\n";
}
/// Main entry point for decls.
static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
emitSourceFileHeader("TypeDef Declarations", os);
SmallVector<TypeDef, 16> typeDefs;
findAllTypeDefs(recordKeeper, typeDefs);
IfDefScope scope("GET_TYPEDEF_CLASSES", os);
if (typeDefs.size() > 0) {
NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
// Well known print/parse dispatch function declarations. These are called
// from Dialect::parseType() and Dialect::printType() methods.
os << " ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
"::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n";
os << " ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
"::mlir::DialectAsmPrinter& printer);\n";
os << "\n";
// Declare all the type classes first (in case they reference each other).
for (const TypeDef &typeDef : typeDefs)
os << " class " << typeDef.getCppClassName() << ";\n";
// Declare all the typedefs.
for (const TypeDef &typeDef : typeDefs)
emitTypeDefDecl(typeDef, os);
}
return false;
}
//===----------------------------------------------------------------------===//
// GEN: TypeDef list
//===----------------------------------------------------------------------===//
static void emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
raw_ostream &os) {
IfDefScope scope("GET_TYPEDEF_LIST", os);
for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) {
os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName();
if (i < typeDefs.end() - 1)
os << ",\n";
else
os << "\n";
}
}
//===----------------------------------------------------------------------===//
// GEN: TypeDef definitions
//===----------------------------------------------------------------------===//
/// Beginning of storage class.
/// {0}: Storage class namespace.
/// {1}: Storage class c++ name.
/// {2}: Parameters parameters.
/// {3}: Parameter initialzer string.
/// {4}: Parameter name list.
/// {5}: Parameter types.
static const char *const typeDefStorageClassBegin = R"(
namespace {0} {{
struct {1} : public ::mlir::TypeStorage {{
{1} ({2})
: {3} {{ }
/// The hash key for this storage is a pair of the integer and type params.
using KeyTy = std::tuple<{5}>;
/// Define the comparison function for the key type.
bool operator==(const KeyTy &key) const {{
return key == KeyTy({4});
}
)";
/// The storage class' constructor template.
/// {0}: storage class name.
static const char *const typeDefStorageClassConstructorBegin = R"(
/// Define a construction method for creating a new instance of this storage.
static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{
)";
/// The storage class' constructor return template.
/// {0}: storage class name.
/// {1}: list of parameters.
static const char *const typeDefStorageClassConstructorReturn = R"(
return new (allocator.allocate<{0}>())
{0}({1});
}
)";
/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
for (TypeParameter &parameter : parameters) {
auto allocCode = parameter.getAllocator();
if (allocCode) {
fmtCtxt.withSelf(parameter.getName());
fmtCtxt.addSubst("_dst", parameter.getName());
os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n";
}
}
}
/// Emit the storage class code for type 'typeDef'.
/// This includes (in-order):
/// 1) typeDefStorageClassBegin, which includes:
/// - The class constructor.
/// - The KeyTy definition.
/// - The equality (==) operator.
/// 2) The hashKey method.
/// 3) The construct method.
/// 4) The list of parameters as the storage class member variables.
static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
// Initialize a bunch of variables to be used later on.
auto parameterNames = map_range(
parameters, [](TypeParameter parameter) { return parameter.getName(); });
auto parameterTypes = map_range(parameters, [](TypeParameter parameter) {
return parameter.getCppType();
});
auto parameterList = join(parameterNames, ", ");
auto parameterTypeList = join(parameterTypes, ", ");
// 1) Emit most of the storage class up until the hashKey body.
os << formatv(
typeDefStorageClassBegin, typeDef.getStorageNamespace(),
typeDef.getStorageClassName(),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters),
parameterList, parameterTypeList);
// 2) Emit the haskKey method.
os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
// Extract each parameter from the key.
for (size_t i = 0, e = parameters.size(); i < e; ++i)
os << formatv(" const auto &{0} = std::get<{1}>(key);\n",
parameters[i].getName(), i);
// Then combine them all. This requires all the parameters types to have a
// hash_value defined.
os << " return ::llvm::hash_combine(";
interleaveComma(parameterNames, os);
os << ");\n";
os << " }\n";
// 3) Emit the construct method.
if (typeDef.hasStorageCustomConstructor())
// If user wants to build the storage constructor themselves, declare it
// here and then they can write the definition elsewhere.
os << " static " << typeDef.getStorageClassName()
<< " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
"&key);\n";
else {
// If not, autogenerate one.
// First, unbox the parameters.
os << formatv(typeDefStorageClassConstructorBegin,
typeDef.getStorageClassName());
for (size_t i = 0; i < parameters.size(); ++i) {
os << formatv(" auto {0} = std::get<{1}>(key);\n",
parameters[i].getName(), i);
}
// Second, reassign the parameter variables with allocation code, if it's
// specified.
emitParameterAllocationCode(typeDef, os);
// Last, return an allocated copy.
os << formatv(typeDefStorageClassConstructorReturn,
typeDef.getStorageClassName(), parameterList);
}
// 4) Emit the parameters as storage class members.
for (auto parameter : parameters) {
os << " " << parameter.getCppType() << " " << parameter.getName()
<< ";\n";
}
os << " };\n";
os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
}
/// Emit the parser and printer for a particular type, if they're specified.
void emitParserPrinter(TypeDef typeDef, raw_ostream &os) {
// Emit the printer code, if specified.
if (auto printerCode = typeDef.getPrinterCode()) {
// Both the mnenomic and printerCode must be defined (for parity with
// parserCode).
os << "void " << typeDef.getCppClassName()
<< "::print(::mlir::DialectAsmPrinter& printer) const {\n";
if (*printerCode == "") {
// If no code specified, emit error.
PrintFatalError(typeDef.getLoc(),
typeDef.getName() +
": printer (if specified) must have non-empty code");
}
auto fmtCtxt = FmtContext().addSubst("_printer", "printer");
os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n";
}
// emit a parser, if specified.
if (auto parserCode = typeDef.getParserCode()) {
// The mnenomic must be defined so the dispatcher knows how to dispatch.
os << "::mlir::Type " << typeDef.getCppClassName()
<< "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
"parser) "
"{\n";
if (*parserCode == "") {
// if no code specified, emit error.
PrintFatalError(typeDef.getLoc(),
typeDef.getName() +
": parser (if specified) must have non-empty code");
}
auto fmtCtxt =
FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt");
os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n";
}
}
/// Print all the typedef-specific definition code.
static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
NamespaceEmitter ns(os, typeDef.getDialect());
SmallVector<TypeParameter, 4> parameters;
typeDef.getParameters(parameters);
// Emit the storage class, if requested and necessary.
if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
emitStorageClass(typeDef, os);
os << llvm::formatv(
"{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
" return Base::get(ctxt",
typeDef.getCppClassName(),
TypeParamCommaFormatter(
TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma,
parameters));
for (TypeParameter &param : parameters)
os << ", " << param.getName();
os << ");\n}\n";
// Emit the parameter accessors.
if (typeDef.genAccessors())
for (const TypeParameter &parameter : parameters) {
SmallString<16> name = parameter.getName();
name[0] = llvm::toUpper(name[0]);
os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n",
parameter.getCppType(), name, parameter.getName(),
typeDef.getCppClassName());
}
// If mnemonic is specified maybe print definitions for the parser and printer
// code, if they're specified.
if (typeDef.getMnemonic())
emitParserPrinter(typeDef, os);
}
/// Emit the dialect printer/parser dispatcher. User's code should call these
/// functions from their dialect's print/parse methods.
static void emitParsePrintDispatch(SmallVectorImpl<TypeDef> &typeDefs,
raw_ostream &os) {
if (typeDefs.size() == 0)
return;
const Dialect &dialect = typeDefs.begin()->getDialect();
NamespaceEmitter ns(os, dialect);
// The parser dispatch is just a list of if-elses, matching on the mnemonic
// and calling the class's parse function.
os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
"::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
for (const TypeDef &typeDef : typeDefs)
if (typeDef.getMnemonic())
os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
"{0}::{1}::parse(ctxt, parser);\n",
typeDef.getDialect().getCppNamespace(),
typeDef.getCppClassName());
os << " return ::mlir::Type();\n";
os << "}\n\n";
// The printer dispatch uses llvm::TypeSwitch to find and call the correct
// printer.
os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, "
"::mlir::DialectAsmPrinter& printer) {\n"
<< " ::mlir::LogicalResult found = ::mlir::success();\n"
<< " ::llvm::TypeSwitch<::mlir::Type>(type)\n";
for (auto typeDef : typeDefs)
if (typeDef.getMnemonic())
os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ "
"t.dyn_cast<{0}::{1}>().print(printer); })\n",
typeDef.getDialect().getCppNamespace(),
typeDef.getCppClassName());
os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); "
"});\n"
<< " return found;\n"
<< "}\n\n";
}
/// Entry point for typedef definitions.
static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
raw_ostream &os) {
emitSourceFileHeader("TypeDef Definitions", os);
SmallVector<TypeDef, 16> typeDefs;
findAllTypeDefs(recordKeeper, typeDefs);
emitTypeDefList(typeDefs, os);
IfDefScope scope("GET_TYPEDEF_CLASSES", os);
emitParsePrintDispatch(typeDefs, os);
for (auto typeDef : typeDefs)
emitTypeDefDef(typeDef, os);
return false;
}
//===----------------------------------------------------------------------===//
// GEN: TypeDef registration hooks
//===----------------------------------------------------------------------===//
static mlir::GenRegistration
genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return emitTypeDefDefs(records, os);
});
static mlir::GenRegistration
genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return emitTypeDefDecls(records, os);
});