[mlir][gpu] GPUToROCDL/NVVM: use generic llvm conversion interface instead of hardcoded conversions. (#124439)

Using `ConvertToLLVMPatternInterface` allows to unhardcode specific
dialect conversions from passes and, more importantly, allows downstream
projects to inject their ops/types translation here by registering
corresponding interface.

Add `allowed-dialects` option so user can control which dialects can be
used to populate conversions.
This commit is contained in:
Ivan Butygin
2025-02-13 15:53:12 +01:00
committed by GitHub
parent 3e54964dcf
commit aecb764cc2
7 changed files with 114 additions and 39 deletions

View File

@@ -572,14 +572,16 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
/*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
"Bitwidth of the index type, 0 to use size of machine word">,
Option<"hasRedux", "has-redux", "bool", /*default=*/"false",
"Target gpu supports redux">,
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
/*default=*/"false",
"Replace memref arguments in GPU functions with bare pointers. "
"All memrefs must have static shape.">
"All memrefs must have static shape.">,
ListOption<"allowedDialects", "allowed-dialects", "std::string",
"Run conversion patterns of only the specified dialects">,
];
}
@@ -600,20 +602,24 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
/*default=*/"\"gfx000\"",
"Chipset that these operations will run on">,
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
/*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
"Bitwidth of the index type, 0 to use size of machine word">,
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
/*default=*/"false",
"Replace memref arguments in GPU functions with bare pointers."
"All memrefs must have static shape">,
Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
"::mlir::gpu::amd::Runtime::Unknown",
"Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
[{::llvm::cl::values(
clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", "Unknown (default)"),
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", "OpenCL")
)}]>
"::mlir::gpu::amd::Runtime::Unknown",
"Runtime code will be run on (default is Unknown, can also use HIP "
"or OpenCL)",
[{::llvm::cl::values(
clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown",
"Unknown (default)"),
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL",
"OpenCL"))}]>,
ListOption<"allowedDialects", "allowed-dialects", "std::string",
"Run conversion patterns of only the specified dialects">,
];
}

View File

@@ -11,19 +11,14 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -342,10 +337,15 @@ struct AssertOpToAssertfailLowering
///
/// This pass only handles device code and is not meant to be run on GPU host
/// code.
struct LowerGpuOpsToNVVMOpsPass
struct LowerGpuOpsToNVVMOpsPass final
: public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
using Base::Base;
void getDependentDialects(DialectRegistry &registry) const override {
Base::getDependentDialects(registry);
registerConvertToLLVMDependentDialectLoading(registry);
}
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
@@ -376,17 +376,41 @@ struct LowerGpuOpsToNVVMOpsPass
LLVMTypeConverter converter(m.getContext(), options);
configureGpuToNVVMTypeConverter(converter);
RewritePatternSet llvmPatterns(m.getContext());
LLVMConversionTarget target(getContext());
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
allowedDialects.end());
for (Dialect *dialect : getContext().getLoadedDialects()) {
// Skip math patterns as nvvm needs custom math lowering.
if (isa<math::MathDialect>(dialect))
continue;
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
// Empty `allowedDialectsSet` means all dialects are allowed.
if (!allowedDialectsSet.empty() && !allowed)
continue;
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface) {
// Error out if dialect was explicily specified but doesn't implement
// conversion interface.
if (allowed) {
m.emitError()
<< "dialect does not implement ConvertToLLVMPatternInterface: "
<< dialect->getNamespace();
return signalPassFailure();
}
continue;
}
iface->populateConvertToLLVMConversionPatterns(target, converter,
llvmPatterns);
}
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
if (this->hasRedux)
populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
LLVMConversionTarget target(getContext());
configureGpuToNVVMConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();
@@ -397,6 +421,7 @@ struct LowerGpuOpsToNVVMOpsPass
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
target.addIllegalOp<func::FuncOp>();
target.addIllegalOp<cf::AssertOp>();
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
target.addIllegalDialect<gpu::GPUDialect>();
@@ -472,8 +497,10 @@ void mlir::populateGpuToNVVMConversionPatterns(
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);
// Set higher benefit, so patterns will run before generic LLVM lowering.
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
converter);
converter, /*benefit*/ 10);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(

View File

@@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
@@ -19,8 +18,8 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
@@ -28,8 +27,6 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -202,7 +199,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
//
// This pass only handles device code and is not meant to be run on GPU host
// code.
struct LowerGpuOpsToROCDLOpsPass
struct LowerGpuOpsToROCDLOpsPass final
: public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
LowerGpuOpsToROCDLOpsPass() = default;
LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
@@ -218,6 +215,11 @@ struct LowerGpuOpsToROCDLOpsPass
this->runtime = runtime;
}
void getDependentDialects(DialectRegistry &registry) const override {
Base::getDependentDialects(registry);
registerConvertToLLVMDependentDialectLoading(registry);
}
void runOnOperation() override {
gpu::GPUModuleOp m = getOperation();
MLIRContext *ctx = m.getContext();
@@ -289,18 +291,36 @@ struct LowerGpuOpsToROCDLOpsPass
});
RewritePatternSet llvmPatterns(ctx);
LLVMConversionTarget target(getContext());
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
allowedDialects.end());
for (Dialect *dialect : ctx->getLoadedDialects()) {
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
// Empty `allowedDialectsSet` means all dialects are allowed.
if (!allowedDialectsSet.empty() && !allowed)
continue;
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
if (!iface) {
// Error out if dialect was explicily specified but doesn't implement
// conversion interface.
if (allowed) {
m.emitError()
<< "dialect does not implement ConvertToLLVMPatternInterface: "
<< dialect->getNamespace();
return signalPassFailure();
}
continue;
}
iface->populateConvertToLLVMConversionPatterns(target, converter,
llvmPatterns);
}
mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
*maybeChipset);
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
LLVMConversionTarget target(getContext());
configureGpuToROCDLConversionLegality(target);
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
signalPassFailure();

View File

@@ -0,0 +1,10 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm='allowed-dialects=test' -verify-diagnostics
// expected-error @+1 {{dialect does not implement ConvertToLLVMPatternInterface: test}}
gpu.module @test_module_1 {
func.func @test(%0 : index) -> index {
%1 = test.increment %0 : index
func.return %1 : index
}
}

View File

@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 allowed-dialects=func,arith,cf' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s

View File

@@ -0,0 +1,10 @@
// RUN: mlir-opt %s -convert-gpu-to-rocdl='allowed-dialects=test' -verify-diagnostics
// expected-error @+1 {{dialect does not implement ConvertToLLVMPatternInterface: test}}
gpu.module @test_module_1 {
func.func @test(%0 : index) -> index {
%1 = test.increment %0 : index
func.return %1 : index
}
}

View File

@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-gpu-to-rocdl='allowed-dialects=func,arith,math' -split-input-file | FileCheck %s
// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
// CHECK-LABEL: @test_module