mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
[mlir][gpu] GPUToROCDL/NVVM: use generic llvm conversion interface instead of hardcoded conversions. (#124439)
Using `ConvertToLLVMPatternInterface` removes the hardcoded dialect-specific conversions from these passes and, more importantly, allows downstream projects to inject their own op/type translations by registering the corresponding interface. Add an `allowed-dialects` option so that users can control which dialects are used to populate the conversion patterns.
This commit is contained in:
@@ -572,14 +572,16 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
|
||||
];
|
||||
let options = [
|
||||
Option<"indexBitwidth", "index-bitwidth", "unsigned",
|
||||
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
|
||||
/*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
|
||||
"Bitwidth of the index type, 0 to use size of machine word">,
|
||||
Option<"hasRedux", "has-redux", "bool", /*default=*/"false",
|
||||
"Target gpu supports redux">,
|
||||
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
|
||||
/*default=*/"false",
|
||||
"Replace memref arguments in GPU functions with bare pointers. "
|
||||
"All memrefs must have static shape.">
|
||||
"All memrefs must have static shape.">,
|
||||
ListOption<"allowedDialects", "allowed-dialects", "std::string",
|
||||
"Run conversion patterns of only the specified dialects">,
|
||||
];
|
||||
}
|
||||
|
||||
@@ -600,20 +602,24 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
|
||||
/*default=*/"\"gfx000\"",
|
||||
"Chipset that these operations will run on">,
|
||||
Option<"indexBitwidth", "index-bitwidth", "unsigned",
|
||||
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
|
||||
/*default=kDeriveIndexBitwidthFromDataLayout*/ "0",
|
||||
"Bitwidth of the index type, 0 to use size of machine word">,
|
||||
Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool",
|
||||
/*default=*/"false",
|
||||
"Replace memref arguments in GPU functions with bare pointers."
|
||||
"All memrefs must have static shape">,
|
||||
Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime",
|
||||
"::mlir::gpu::amd::Runtime::Unknown",
|
||||
"Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)",
|
||||
[{::llvm::cl::values(
|
||||
clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", "Unknown (default)"),
|
||||
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
|
||||
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", "OpenCL")
|
||||
)}]>
|
||||
"::mlir::gpu::amd::Runtime::Unknown",
|
||||
"Runtime code will be run on (default is Unknown, can also use HIP "
|
||||
"or OpenCL)",
|
||||
[{::llvm::cl::values(
|
||||
clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown",
|
||||
"Unknown (default)"),
|
||||
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
|
||||
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL",
|
||||
"OpenCL"))}]>,
|
||||
ListOption<"allowedDialects", "allowed-dialects", "std::string",
|
||||
"Run conversion patterns of only the specified dialects">,
|
||||
];
|
||||
}
|
||||
|
||||
|
||||
@@ -11,19 +11,14 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
|
||||
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
|
||||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
@@ -342,10 +337,15 @@ struct AssertOpToAssertfailLowering
|
||||
///
|
||||
/// This pass only handles device code and is not meant to be run on GPU host
|
||||
/// code.
|
||||
struct LowerGpuOpsToNVVMOpsPass
|
||||
struct LowerGpuOpsToNVVMOpsPass final
|
||||
: public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
|
||||
using Base::Base;
|
||||
|
||||
void getDependentDialects(DialectRegistry &registry) const override {
|
||||
Base::getDependentDialects(registry);
|
||||
registerConvertToLLVMDependentDialectLoading(registry);
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
gpu::GPUModuleOp m = getOperation();
|
||||
|
||||
@@ -376,17 +376,41 @@ struct LowerGpuOpsToNVVMOpsPass
|
||||
LLVMTypeConverter converter(m.getContext(), options);
|
||||
configureGpuToNVVMTypeConverter(converter);
|
||||
RewritePatternSet llvmPatterns(m.getContext());
|
||||
LLVMConversionTarget target(getContext());
|
||||
|
||||
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
|
||||
allowedDialects.end());
|
||||
for (Dialect *dialect : getContext().getLoadedDialects()) {
|
||||
// Skip math patterns as nvvm needs custom math lowering.
|
||||
if (isa<math::MathDialect>(dialect))
|
||||
continue;
|
||||
|
||||
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
|
||||
// Empty `allowedDialectsSet` means all dialects are allowed.
|
||||
if (!allowedDialectsSet.empty() && !allowed)
|
||||
continue;
|
||||
|
||||
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
|
||||
if (!iface) {
|
||||
// Error out if dialect was explicitly specified but doesn't implement
|
||||
// conversion interface.
|
||||
if (allowed) {
|
||||
m.emitError()
|
||||
<< "dialect does not implement ConvertToLLVMPatternInterface: "
|
||||
<< dialect->getNamespace();
|
||||
return signalPassFailure();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
iface->populateConvertToLLVMConversionPatterns(target, converter,
|
||||
llvmPatterns);
|
||||
}
|
||||
|
||||
arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
|
||||
populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
|
||||
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
if (this->hasRedux)
|
||||
populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
|
||||
LLVMConversionTarget target(getContext());
|
||||
configureGpuToNVVMConversionLegality(target);
|
||||
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
|
||||
signalPassFailure();
|
||||
@@ -397,6 +421,7 @@ struct LowerGpuOpsToNVVMOpsPass
|
||||
|
||||
void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) {
|
||||
target.addIllegalOp<func::FuncOp>();
|
||||
target.addIllegalOp<cf::AssertOp>();
|
||||
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
|
||||
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
|
||||
target.addIllegalDialect<gpu::GPUDialect>();
|
||||
@@ -472,8 +497,10 @@ void mlir::populateGpuToNVVMConversionPatterns(
|
||||
using gpu::index_lowering::IndexKind;
|
||||
using gpu::index_lowering::IntrType;
|
||||
populateWithGenerated(patterns);
|
||||
|
||||
// Set higher benefit, so patterns will run before generic LLVM lowering.
|
||||
patterns.add<GPUPrintfOpToVPrintfLowering, AssertOpToAssertfailLowering>(
|
||||
converter);
|
||||
converter, /*benefit*/ 10);
|
||||
patterns.add<
|
||||
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
|
||||
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
|
||||
#include "mlir/Dialect/Arith/Transforms/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
@@ -19,8 +18,8 @@
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
|
||||
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
||||
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
||||
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
|
||||
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
||||
@@ -28,8 +27,6 @@
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
||||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||||
@@ -202,7 +199,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
|
||||
//
|
||||
// This pass only handles device code and is not meant to be run on GPU host
|
||||
// code.
|
||||
struct LowerGpuOpsToROCDLOpsPass
|
||||
struct LowerGpuOpsToROCDLOpsPass final
|
||||
: public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
|
||||
LowerGpuOpsToROCDLOpsPass() = default;
|
||||
LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
|
||||
@@ -218,6 +215,11 @@ struct LowerGpuOpsToROCDLOpsPass
|
||||
this->runtime = runtime;
|
||||
}
|
||||
|
||||
void getDependentDialects(DialectRegistry &registry) const override {
|
||||
Base::getDependentDialects(registry);
|
||||
registerConvertToLLVMDependentDialectLoading(registry);
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
gpu::GPUModuleOp m = getOperation();
|
||||
MLIRContext *ctx = m.getContext();
|
||||
@@ -289,18 +291,36 @@ struct LowerGpuOpsToROCDLOpsPass
|
||||
});
|
||||
|
||||
RewritePatternSet llvmPatterns(ctx);
|
||||
LLVMConversionTarget target(getContext());
|
||||
|
||||
llvm::SmallDenseSet<StringRef> allowedDialectsSet(allowedDialects.begin(),
|
||||
allowedDialects.end());
|
||||
for (Dialect *dialect : ctx->getLoadedDialects()) {
|
||||
bool allowed = allowedDialectsSet.contains(dialect->getNamespace());
|
||||
// Empty `allowedDialectsSet` means all dialects are allowed.
|
||||
if (!allowedDialectsSet.empty() && !allowed)
|
||||
continue;
|
||||
|
||||
auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
|
||||
if (!iface) {
|
||||
// Error out if dialect was explicitly specified but doesn't implement
|
||||
// conversion interface.
|
||||
if (allowed) {
|
||||
m.emitError()
|
||||
<< "dialect does not implement ConvertToLLVMPatternInterface: "
|
||||
<< dialect->getNamespace();
|
||||
return signalPassFailure();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
iface->populateConvertToLLVMConversionPatterns(target, converter,
|
||||
llvmPatterns);
|
||||
}
|
||||
|
||||
mlir::arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
|
||||
*maybeChipset);
|
||||
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
cf::populateAssertToLLVMConversionPattern(converter, llvmPatterns);
|
||||
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
|
||||
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
|
||||
LLVMConversionTarget target(getContext());
|
||||
configureGpuToROCDLConversionLegality(target);
|
||||
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
|
||||
signalPassFailure();
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: mlir-opt %s -convert-gpu-to-nvvm='allowed-dialects=test' -verify-diagnostics
|
||||
|
||||
// expected-error @+1 {{dialect does not implement ConvertToLLVMPatternInterface: test}}
|
||||
gpu.module @test_module_1 {
|
||||
func.func @test(%0 : index) -> index {
|
||||
%1 = test.increment %0 : index
|
||||
func.return %1 : index
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 allowed-dialects=func,arith,cf' -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
|
||||
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
|
||||
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: mlir-opt %s -convert-gpu-to-rocdl='allowed-dialects=test' -verify-diagnostics
|
||||
|
||||
// expected-error @+1 {{dialect does not implement ConvertToLLVMPatternInterface: test}}
|
||||
gpu.module @test_module_1 {
|
||||
func.func @test(%0 : index) -> index {
|
||||
%1 = test.increment %0 : index
|
||||
func.return %1 : index
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -convert-gpu-to-rocdl='allowed-dialects=func,arith,math' -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
|
||||
|
||||
// CHECK-LABEL: @test_module
|
||||
|
||||
Reference in New Issue
Block a user