diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 6d2d78d5825f..38390d801134 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -697,7 +697,7 @@ static bool isOne(mlir::Value v) { return checkIsIntegerConstant(v, 1); } template struct UndoComplexPattern : public mlir::RewritePattern { UndoComplexPattern(mlir::MLIRContext *ctx) - : mlir::RewritePattern("fir.insert_value", {}, 2, ctx) {} + : mlir::RewritePattern("fir.insert_value", 2, ctx) {} mlir::LogicalResult matchAndRewrite(mlir::Operation *op, diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h index f66a29250aa2..eeb20c4806b9 100644 --- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h +++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h @@ -30,12 +30,12 @@ namespace linalg { // or in an externally linked library. // This is a generic entry point for all LinalgOp, except for CopyOp and // IndexedGenericOp, for which omre specialized patterns are provided. -class LinalgOpToLibraryCallRewrite : public RewritePattern { +class LinalgOpToLibraryCallRewrite + : public OpInterfaceRewritePattern { public: - LinalgOpToLibraryCallRewrite() - : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override; }; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h index d005cc310abe..bee0d5a12800 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -60,7 +60,8 @@ void enqueue(RewritePatternSet &patternList, OptionsType options, if (!opName.empty()) patternList.add(opName, patternList.getContext(), options, m); else - patternList.add(m.addOpFilter(), options); + patternList.add(patternList.getContext(), + m.addOpFilter(), options); } /// Promotion transformation enqueues a particular stage-1 pattern for diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 71dbe9fb24cd..21e6cba9dc3c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -452,7 +452,7 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns); struct LinalgBaseTilingPattern : public RewritePattern { // Entry point to match any LinalgOp OpInterface. LinalgBaseTilingPattern( - LinalgTilingOptions options, + MLIRContext *context, LinalgTilingOptions options, LinalgTransformationFilter filter = LinalgTransformationFilter(), PatternBenefit benefit = 1); // Entry point to match a specific Linalg op. @@ -644,7 +644,8 @@ struct LinalgVectorizationOptions {}; struct LinalgBaseVectorizationPattern : public RewritePattern { /// MatchAnyOpTag-based constructor with a mandatory `filter`. - LinalgBaseVectorizationPattern(LinalgTransformationFilter filter, + LinalgBaseVectorizationPattern(MLIRContext *context, + LinalgTransformationFilter filter, PatternBenefit benefit = 1); /// Name-based constructor with an optional `filter`. LinalgBaseVectorizationPattern( @@ -663,10 +664,10 @@ struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern { /// These constructors are available to anyone. /// MatchAnyOpTag-based constructor with a mandatory `filter`. LinalgVectorizationPattern( - LinalgTransformationFilter filter, + MLIRContext *context, LinalgTransformationFilter filter, LinalgVectorizationOptions options = LinalgVectorizationOptions(), PatternBenefit benefit = 1) - : LinalgBaseVectorizationPattern(filter, benefit) {} + : LinalgBaseVectorizationPattern(context, filter, benefit) {} /// Name-based constructor with an optional `filter`. LinalgVectorizationPattern( StringRef opName, MLIRContext *context, @@ -702,8 +703,8 @@ template (f.addOpFilter(), - options); + patternList.add( + patternList.getContext(), f.addOpFilter(), options); } /// Variadic helper function to insert vectorization patterns for C++ ops. @@ -737,7 +738,7 @@ struct LinalgLoweringPattern : public RewritePattern { MLIRContext *context, LinalgLoweringType loweringType, LinalgTransformationFilter filter = LinalgTransformationFilter(), ArrayRef interchangeVector = {}, PatternBenefit benefit = 1) - : RewritePattern(OpTy::getOperationName(), {}, benefit, context), + : RewritePattern(OpTy::getOperationName(), benefit, context), filter(filter), loweringType(loweringType), interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h index 35eb83d8f03a..b765dafbce46 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -123,7 +123,8 @@ struct UnrollVectorOptions { struct UnrollVectorPattern : public RewritePattern { using FilterConstraintType = std::function; UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options) - : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {} + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), + options(options) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (options.filterConstraint && failed(options.filterConstraint(op))) @@ -216,7 +217,7 @@ struct VectorTransferFullPartialRewriter : public RewritePattern { FilterConstraintType filter = [](VectorTransferOpInterface op) { return success(); }, PatternBenefit benefit = 1) - : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options), + : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options), filter(filter) {} /// Performs the rewrite. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index b27e1e0e4a78..ec3884e58fc3 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1516,6 +1516,13 @@ public: #endif return false; } + /// Provide `classof` support for other OpBase derived classes, such as + /// Interfaces. + template + static std::enable_if_t::value, bool> + classof(const T *op) { + return classof(const_cast(op)->getOperation()); + } /// Expose the type we are instantiated on to template machinery that may want /// to introspect traits on this operation. diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 8cc97d9c02ee..cb82ec9c4714 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -142,12 +142,20 @@ public: return interfaceMap.lookup(); } + /// Returns true if this operation has the given interface registered to it. + bool hasInterface(TypeID interfaceID) const { + return interfaceMap.contains(interfaceID); + } + /// Returns true if the operation has a particular trait. template