From e152f0194fdaff632a5a9737d4ea057218871782 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 19 Aug 2019 12:12:50 -0700 Subject: [PATCH] NFC: Don't assume that all operation traits are within the 'OpTrait::' namespace. This places an unnecessary restriction that all traits are within this namespace. PiperOrigin-RevId: 264212000 --- mlir/include/mlir/IR/OpBase.td | 4 +-- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 37 +++++++++++---------- mlir/tools/mlir-tblgen/RewriterGen.cpp | 7 ++-- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 3183a762da51..eb49c237bb21 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -996,7 +996,7 @@ class OpTrait; // purpose to wrap around C++ symbol string with this class is to make // traits specified for ops in TableGen less alien and more integrated. class NativeOpTrait : OpTrait { - string trait = prop; + string trait = "OpTrait::" # prop; } // ParamNativeOpTrait corresponds to the template-parameterized traits in the @@ -1012,7 +1012,7 @@ class ParamNativeOpTrait // affects op definition generator internals, like how op builders and // operand/attribute/result getters are generated. class GenInternalOpTrait : OpTrait { - string trait = prop; + string trait = "OpTrait::" # prop; } // PredOpTrait is an op trait implemented by way of a predicate on the op. diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index d5e6cf4a7716..b37126202756 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -402,10 +402,8 @@ void Class::writeDefTo(raw_ostream &os) const { OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) : Class(name), extraClassDeclaration(extraClassDeclaration) {} -// Adds the given trait to this op. Prefixes "OpTrait::" to `trait` implicitly. -void OpClass::addTrait(Twine trait) { - traits.push_back(("OpTrait::" + trait).str()); -} +// Adds the given trait to this op. +void OpClass::addTrait(Twine trait) { traits.push_back(trait.str()); } void OpClass::writeDeclTo(raw_ostream &os) const { os << "class " << className << " : public Op<" << className; @@ -649,7 +647,8 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass, const int numVariadicOperands = op.getNumVariadicOperands(); const int numNormalOperands = numOperands - numVariadicOperands; - if (numVariadicOperands > 1 && !op.hasTrait("SameVariadicOperandSize")) { + if (numVariadicOperands > 1 && + !op.hasTrait("OpTrait::SameVariadicOperandSize")) { PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " "specification over their sizes"); } @@ -712,7 +711,8 @@ void OpEmitter::genNamedResultGetters() { // If we have more than one variadic results, we need more complicated logic // to calculate the value range for each result. - if (numVariadicResults > 1 && !op.hasTrait("SameVariadicResultSize")) { + if (numVariadicResults > 1 && + !op.hasTrait("OpTrait::SameVariadicResultSize")) { PrintFatalError(op.getLoc(), "op has multiple variadic results but no " "specification over their sizes"); } @@ -868,9 +868,9 @@ void OpEmitter::genBuilder() { // use the first operand or attribute's type as all result types // to facilitate different call patterns. if (op.getNumVariadicResults() == 0) { - if (op.hasTrait("SameOperandsAndResultType")) + if (op.hasTrait("OpTrait::SameOperandsAndResultType")) genUseOperandAsResultTypeBuilder(); - if (op.hasTrait("FirstAttrDerivedResultType")) + if (op.hasTrait("OpTrait::FirstAttrDerivedResultType")) genUseAttrAsResultTypeBuilder(); } } @@ -1224,19 +1224,20 @@ void OpEmitter::genTraits() { // Add return size trait. if (numVariadicResults != 0) { if (numResults == numVariadicResults) - opClass.addTrait("VariadicResults"); + opClass.addTrait("OpTrait::VariadicResults"); else - opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) + ">::Impl"); + opClass.addTrait("OpTrait::AtLeastNResults<" + Twine(numResults - 1) + + ">::Impl"); } else { switch (numResults) { case 0: - opClass.addTrait("ZeroResult"); + opClass.addTrait("OpTrait::ZeroResult"); break; case 1: - opClass.addTrait("OneResult"); + opClass.addTrait("OpTrait::OneResult"); break; default: - opClass.addTrait("NResults<" + Twine(numResults) + ">::Impl"); + opClass.addTrait("OpTrait::NResults<" + Twine(numResults) + ">::Impl"); break; } } @@ -1253,20 +1254,20 @@ void OpEmitter::genTraits() { // Add operand size trait. if (numVariadicOperands != 0) { if (numOperands == numVariadicOperands) - opClass.addTrait("VariadicOperands"); + opClass.addTrait("OpTrait::VariadicOperands"); else - opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) + + opClass.addTrait("OpTrait::AtLeastNOperands<" + Twine(numOperands - 1) + ">::Impl"); } else { switch (numOperands) { case 0: - opClass.addTrait("ZeroOperands"); + opClass.addTrait("OpTrait::ZeroOperands"); break; case 1: - opClass.addTrait("OneOperand"); + opClass.addTrait("OpTrait::OneOperand"); break; default: - opClass.addTrait("NOperands<" + Twine(numOperands) + ">::Impl"); + opClass.addTrait("OpTrait::NOperands<" + Twine(numOperands) + ">::Impl"); break; } } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 3487eda545fd..0054c38dd239 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -649,9 +649,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // special cases listed below, we need to supply types for all results // when building an op. bool isSameOperandsAndResultType = - resultOp.hasTrait("SameOperandsAndResultType"); - bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult"); - bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType"); + resultOp.hasTrait("OpTrait::SameOperandsAndResultType"); + bool isBroadcastable = + resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult"); + bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType"); bool usePartialResults = valuePackName != resultValue; if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||