From b96f86daaf8420cd61a0459de14f196c8eca871b Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Wed, 11 Dec 2019 07:13:54 -0800 Subject: [PATCH] Add a function to get lowering patterns from GPU to NVVM. This enables combining the patterns with other patterns into larger lowerings. PiperOrigin-RevId: 284979271 --- .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h | 4 +++ .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 31 +++++++++++-------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 42a9faf40e96..635d4366e834 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -26,6 +26,10 @@ class OwningRewritePatternList; class ModuleOp; template class OpPassBase; +/// Collect a set of patterns to convert from the GPU dialect to NVVM. +void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); + /// Creates a pass that lowers GPU dialect operations to NVVM counterparts. std::unique_ptr> createLowerGpuOpsToNVVMOpsPass(); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 5c772cc85f2f..c4cc4efd31f8 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -605,19 +605,7 @@ public: OwningRewritePatternList patterns; LLVMTypeConverter converter(m.getContext()); populateStdToLLVMConversionPatterns(converter, patterns); - populateWithGenerated(&getContext(), &patterns); - patterns.insert< - GPUIndexIntrinsicOpLowering, - GPUIndexIntrinsicOpLowering, - GPUIndexIntrinsicOpLowering, - GPUIndexIntrinsicOpLowering, - GPUAllReduceOpLowering, FuncOpLowering>(converter); - patterns.insert>(converter, "__nv_expf", - "__nv_exp"); + populateGpuToNVVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); target.addIllegalDialect(); target.addIllegalOp(); @@ -633,6 +621,23 @@ public: } // anonymous namespace +void mlir::populateGpuToNVVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + populateWithGenerated(converter.getDialect()->getContext(), &patterns); + patterns + .insert, + GPUIndexIntrinsicOpLowering, + GPUIndexIntrinsicOpLowering, + GPUIndexIntrinsicOpLowering, + GPUAllReduceOpLowering, FuncOpLowering>(converter); + patterns.insert>(converter, "__nv_expf", + "__nv_exp"); +} + std::unique_ptr> mlir::createLowerGpuOpsToNVVMOpsPass() { return std::make_unique(); }