NFC: refactor ODS builder generation

Previously we use one single method with lots of branches to
generate multiple builders. This makes the method difficult
to follow and modify. This CL splits the method into multiple
dedicated ones, by extracting common logic into helper methods
while leaving logic specific to each builder in their own
methods.

PiperOrigin-RevId: 261011082
This commit is contained in:
Lei Zhang
2019-07-31 15:30:46 -07:00
committed by A. Unique TensorFlower
parent cf66d7bb74
commit e44ba1f8bf
3 changed files with 201 additions and 159 deletions

View File

@@ -19,7 +19,7 @@ def OpA : NS_Op<"one_normal_operand_op", []> {
// CHECK: void OpA::build
// CHECK-SAME: Value *input
// CHECK: tblgen_state->operands.push_back(input);
// CHECK: tblgen_state->addOperands(input);
// CHECK: void OpA::build
// CHECK-SAME: ArrayRef<Value *> operands
@@ -56,5 +56,5 @@ def OpD : NS_Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]>
// CHECK-LABEL: OpD::build
// CHECK-NEXT: tblgen_state->addOperands(input1);
// CHECK-NEXT: tblgen_state->operands.push_back(input2);
// CHECK-NEXT: tblgen_state->addOperands(input2);
// CHECK-NEXT: tblgen_state->addOperands(input3);

View File

@@ -24,7 +24,7 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
// CHECK-LABEL: OpB definitions
// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Type y, Value *x)
// CHECK: tblgen_state->types.push_back(y);
// CHECK: tblgen_state->addTypes(y);
// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Value *x)
// CHECK: tblgen_state->addTypes({x->getType()});
@@ -34,9 +34,9 @@ def OpC : NS_Op<"three_normal_result_op", []> {
// CHECK-LABEL: OpC definitions
// CHECK: void OpC::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z)
// CHECK-NEXT: tblgen_state->types.push_back(x)
// CHECK-NEXT: tblgen_state->types.push_back(resultType1)
// CHECK-NEXT: tblgen_state->types.push_back(z)
// CHECK-NEXT: tblgen_state->addTypes(x)
// CHECK-NEXT: tblgen_state->addTypes(resultType1)
// CHECK-NEXT: tblgen_state->addTypes(z)
def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
@@ -74,7 +74,7 @@ def OpG : NS_Op<"one_normal_and_one_variadic_result_op", []> {
// CHECK-LABEL: OpG definitions
// CHECK: void OpG::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef<Type> y)
// CHECK-NEXT: tblgen_state->types.push_back(x);
// CHECK-NEXT: tblgen_state->addTypes(x);
// CHECK-NEXT: tblgen_state->addTypes(y);
// CHECK: void OpG::build
@@ -94,7 +94,7 @@ def OpI : NS_Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]>
// CHECK-LABEL: OpI::build
// CHECK-NEXT: tblgen_state->addTypes(output1);
// CHECK-NEXT: tblgen_state->types.push_back(output2);
// CHECK-NEXT: tblgen_state->addTypes(output2);
// CHECK-NEXT: tblgen_state->addTypes(output3);
// Test that if the only operand is variadic, we acess the first value in the

View File

@@ -460,6 +460,9 @@ private:
void emitDecl(raw_ostream &os);
void emitDef(raw_ostream &os);
// Generates the `getOperationName` method for this op.
void genOpNameGetter();
// Generates getters for the attributes.
void genAttrGetters();
@@ -472,9 +475,40 @@ private:
// Generates getters for named regions.
void genNamedRegionGetters();
// Generates builder method for the operation.
// Generates builder methods for the operation.
void genBuilder();
// Generates the build() method that takes each result-type/operand/attribute
// as a stand-alone parameter. This build() method also requires specifying
// result types for all results.
void genSeparateParamBuilder();
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. This build() method uses first operand's type
// as all result's types.
void genUseOperandAsResultTypeBuilder();
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. This build() method uses first attribute's type
// as all result's types.
void genUseAttrAsResultTypeBuilder();
// Generates the build() method that takes all result types collectively as
// one parameter. Similarly for operands and attributes.
void genCollectiveParamBuilder();
// Builds the parameter list for build() method of this op. This method writes
// to `paramList` the comma-separated parameter list. If `includeResultTypes`
// is true then `paramList` will also contain the parameters for all results
// and `resultTypeNames` will be populated with the parameter name for each
// result type.
void buildParamList(std::string &paramList,
SmallVectorImpl<std::string> &resultTypeNames,
bool includeResultTypes);
// Adds op arguments and regions into operation state for build() methods.
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body);
// Generates canonicalizer declaration for the operation.
void genCanonicalizerDecls();
@@ -503,15 +537,7 @@ private:
// Generates the traits used by the object.
void genTraits();
// Generates the build() method that takes each result-type/operand/attribute
// as a stand-alone parameter. Using the first operand's type as all result
// types if `useOperandType` is true. Using the first attribute's type as all
// result types if `useAttrType` true. Don't set `useOperandType` and
// `useAttrType` at the same time.
void genStandaloneParamBuilder(bool useOperandType, bool useAttrType);
void genOpNameGetter();
private:
// The TableGen record for this op.
// TODO(antiagainst,zinenko): OpEmitter should not have a Record directly,
// it should rather go through the Operator for better abstraction.
@@ -736,132 +762,67 @@ void OpEmitter::genNamedRegionGetters() {
}
}
void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
bool useAttrType) {
if (useOperandType && useAttrType) {
PrintFatalError(def.getLoc(),
"Op definition has both 'SameOperandsAndResultType' and "
"'FirstAttrIsResultType' trait specified.");
}
auto numResults = op.getNumResults();
void OpEmitter::genSeparateParamBuilder() {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
resultNames.reserve(numResults);
std::string paramList = "Builder *, OperationState *";
paramList.append(builderOpState);
// Emit parameters for all return types
if (!useOperandType && !useAttrType) {
for (int i = 0; i != numResults; ++i) {
const auto &result = op.getResult(i);
std::string resultName = result.name;
if (resultName.empty())
resultName = formatv("resultType{0}", i);
paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
paramList.append(resultName);
resultNames.emplace_back(std::move(resultName));
}
}
// Emit parameters for all arguments (operands and attributes).
int numOperands = 0;
int numAttrs = 0;
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands);
paramList.append(operand.isVariadic() ? ", ArrayRef<Value *> "
: ", Value *");
paramList.append(getArgumentName(op, numOperands));
++numOperands;
} else {
// TODO(antiagainst): Support default initializer for attributes
const auto &namedAttr = op.getAttribute(numAttrs);
const auto &attr = namedAttr.attr;
paramList.append(", ");
if (attr.isOptional())
paramList.append("/*optional*/");
paramList.append(
(attr.getStorageType() + Twine(" ") + namedAttr.name).str());
++numAttrs;
}
}
if (numOperands + numAttrs != op.getNumArgs())
PrintFatalError("op arguments must be either operands or attributes");
buildParamList(paramList, resultNames, /*includeResultTypes=*/true);
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
// Push all result types to the result
if (numResults > 0) {
if (!useOperandType && !useAttrType) {
for (int i = 0; i < numResults; ++i) {
const auto &result = op.getResult(i);
m.body() << " " << builderOpState;
if (result.isVariadic()) {
m.body() << "->addTypes(";
} else {
m.body() << "->types.push_back(";
}
m.body() << resultNames[i] << ");\n";
}
} else {
std::string resultType;
if (useAttrType) {
const auto &namedAttr = op.getAttribute(0);
if (namedAttr.attr.isTypeAttr()) {
resultType = formatv("{0}.getValue()", namedAttr.name);
} else {
resultType = formatv("{0}.getType()", namedAttr.name);
}
} else {
const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
resultType =
formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str();
}
m.body() << " " << builderOpState << "->addTypes({" << resultType;
for (int i = 1; i != numResults; ++i)
m.body() << ", " << resultType;
m.body() << "});\n\n";
}
// Push all result types to the operation state
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
m.body() << " " << builderOpState << "->addTypes(" << resultNames[i]
<< ");\n";
}
}
// Push all operands to the result
for (int i = 0; i < numOperands; ++i) {
const auto &operand = op.getOperand(i);
m.body() << " " << builderOpState;
if (operand.isVariadic()) {
m.body() << "->addOperands(";
} else {
m.body() << "->operands.push_back(";
}
m.body() << getArgumentName(op, i) << ");\n";
}
void OpEmitter::genUseOperandAsResultTypeBuilder() {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, /*includeResultTypes=*/false);
// Push all attributes to the result
for (const auto &namedAttr : op.getAttributes()) {
if (!namedAttr.attr.isDerivedAttr()) {
bool emitNotNullCheck = namedAttr.attr.isOptional();
if (emitNotNullCheck) {
m.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n";
}
m.body() << formatv(" {0}->addAttribute(\"{1}\", {1});\n",
builderOpState, namedAttr.name);
if (emitNotNullCheck) {
m.body() << " }\n";
}
}
}
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
for (int i = 0; i < numRegions; ++i)
m.body() << " (void)" << builderOpState << "->addRegion();\n";
auto numResults = op.getNumResults();
if (numResults == 0)
return;
// Push all result types to the operation state
const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
std::string resultType =
formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str();
m.body() << " " << builderOpState << "->addTypes({" << resultType;
for (int i = 1; i != numResults; ++i)
m.body() << ", " << resultType;
m.body() << "});\n\n";
}
void OpEmitter::genUseAttrAsResultTypeBuilder() {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, /*includeResultTypes=*/false);
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
genCodeForAddingArgAndRegionForBuilder(m.body());
auto numResults = op.getNumResults();
if (numResults == 0)
return;
// Push all result types to the operation state
std::string resultType;
const auto &namedAttr = op.getAttribute(0);
if (namedAttr.attr.isTypeAttr()) {
resultType = formatv("{0}.getValue()", namedAttr.name);
} else {
resultType = formatv("{0}.getType()", namedAttr.name);
}
m.body() << " " << builderOpState << "->addTypes({" << resultType;
for (int i = 1; i != numResults; ++i)
m.body() << ", " << resultType;
m.body() << "});\n\n";
}
void OpEmitter::genBuilder() {
@@ -893,6 +854,28 @@ void OpEmitter::genBuilder() {
}
}
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
// We generate three builders here:
// 1. one having a stand-alone parameter for each result type / operand /
// attribute, and
genSeparateParamBuilder();
// 2. one having an aggregated parameter for all result types / operands /
// attributes, and
genCollectiveParamBuilder();
// 3. one having a stand-alone prameter for each operand and attribute,
// use the first operand or attribute's type as all result types
// to facilitate different call patterns.
if (op.getNumVariadicResults() == 0) {
if (op.hasTrait("SameOperandsAndResultType"))
genUseOperandAsResultTypeBuilder();
if (op.hasTrait("FirstAttrDerivedResultType"))
genUseAttrAsResultTypeBuilder();
}
}
void OpEmitter::genCollectiveParamBuilder() {
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariadicResults();
int numNonVariadicResults = numResults - numVariadicResults;
@@ -900,25 +883,6 @@ void OpEmitter::genBuilder() {
int numOperands = op.getNumOperands();
int numVariadicOperands = op.getNumVariadicOperands();
int numNonVariadicOperands = numOperands - numVariadicOperands;
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
// We generate three builders here:
// 1. one having a stand-alone parameter for each result type / operand /
// attribute, and
// 2. one having an aggregated parameter for all result types / operands /
// attributes, and
// 3. one having a stand-alone prameter for each operand and attribute,
// use the first operand's type as all result types
// to facilitate different call patterns.
// 1. Stand-alone parameters
genStandaloneParamBuilder(/*useOperandType=*/false, /*useAttrType=*/false);
// 2. Aggregated parameters
// Signature
std::string params =
std::string("Builder *, OperationState *") + builderOpState +
@@ -952,13 +916,91 @@ void OpEmitter::genBuilder() {
for (int i = 0; i < numRegions; ++i)
m.body() << " (void)" << builderOpState << "->addRegion();\n";
}
}
// 3. Deduced result types
void OpEmitter::buildParamList(std::string &paramList,
SmallVectorImpl<std::string> &resultTypeNames,
bool includeResultTypes) {
bool useOperandType = op.hasTrait("SameOperandsAndResultType");
bool useAttrType = op.hasTrait("FirstAttrDerivedResultType");
if (numVariadicResults == 0 && (useOperandType || useAttrType))
genStandaloneParamBuilder(useOperandType, useAttrType);
paramList = "Builder *, OperationState *";
paramList.append(builderOpState);
if (includeResultTypes) {
resultTypeNames.clear();
auto numResults = op.getNumResults();
resultTypeNames.reserve(numResults);
// Add parameters for all return types
for (int i = 0; i < numResults; ++i) {
const auto &result = op.getResult(i);
std::string resultName = result.name;
if (resultName.empty())
resultName = formatv("resultType{0}", i);
paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
paramList.append(resultName);
resultTypeNames.emplace_back(std::move(resultName));
}
}
int numOperands = 0;
int numAttrs = 0;
// Add parameters for all arguments (operands and attributes).
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands);
paramList.append(operand.isVariadic() ? ", ArrayRef<Value *> "
: ", Value *");
paramList.append(getArgumentName(op, numOperands));
++numOperands;
} else {
// TODO(antiagainst): Support default initializer for attributes
const auto &namedAttr = op.getAttribute(numAttrs);
const auto &attr = namedAttr.attr;
paramList.append(", ");
if (attr.isOptional())
paramList.append("/*optional*/");
paramList.append(attr.getStorageType());
paramList.append(" ");
paramList.append(namedAttr.name);
++numAttrs;
}
}
if (numOperands + numAttrs != op.getNumArgs())
PrintFatalError("op arguments must be either operands or attributes");
}
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body) {
// Push all operands to the result
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
body << " " << builderOpState << "->addOperands(" << getArgumentName(op, i)
<< ");\n";
}
// Push all attributes to the result
for (const auto &namedAttr : op.getAttributes()) {
if (!namedAttr.attr.isDerivedAttr()) {
bool emitNotNullCheck = namedAttr.attr.isOptional();
if (emitNotNullCheck) {
body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
}
body << formatv(" {0}->addAttribute(\"{1}\", {1});\n", builderOpState,
namedAttr.name);
if (emitNotNullCheck) {
body << " }\n";
}
}
}
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
for (int i = 0; i < numRegions; ++i)
body << " (void)" << builderOpState << "->addRegion();\n";
}
}
void OpEmitter::genCanonicalizerDecls() {