[mlir][gpu] NFC: let the user pick the threadID values when distributing foreach_thread

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D144219
Thomas Raoux
2023-02-17 00:15:12 +00:00
parent e3a88a41af
commit 0eabb884ab
2 changed files with 25 additions and 15 deletions
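
With this change, mapNestedForeachToThreadsImpl no longer hard-codes the gpu.thread_id x/y/z values: the caller supplies a threadIdGenerator callback that materializes the thread ID Values for each scf.foreach_thread op being distributed. As a minimal sketch of what the new parameter enables (hypothetical, not part of this commit), a caller distributing along a single dimension could pass the x thread ID for all three entries; `rewriter`, `target`, `blockDim`, `transformOp`, and `threadMappingAttributes` are assumed to be in scope, mirroring the call site in applyToOne in the last hunk below.

// Hypothetical custom generator, for illustration only: reuse the x thread
// ID for every dimension instead of creating x/y/z IDs (the default below).
auto customThreadIdGenerator = [](RewriterBase &rewriter,
                                  scf::ForeachThreadOp foreachThreadOp,
                                  SmallVectorImpl<Value> &threadIds) {
  Location loc = foreachThreadOp.getLoc();
  IndexType indexType = rewriter.getIndexType();
  Value tidX = rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x);
  threadIds.assign({tidX, tidX, tidX});
};
DiagnosedSilenceableFailure diag =
    mlir::transform::gpu::mapNestedForeachToThreadsImpl(
        rewriter, target, blockDim, customThreadIdGenerator,
        /*syncAfterDistribute=*/true, transformOp, threadMappingAttributes);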


@@ -42,8 +42,11 @@ namespace gpu {
 /// supported. Dynamic block dim sizes are currently not supported.
 DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
+    const SmallVectorImpl<int64_t> &blockDim,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+                      SmallVectorImpl<Value> &)>
+        threadIdGenerator,
+    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
 /// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is


@@ -502,8 +502,11 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
 
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
+    const SmallVectorImpl<int64_t> &blockDim,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+                      SmallVectorImpl<Value> &)>
+        threadIdGenerator,
+    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
@@ -517,14 +520,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
                               foreachThreadOp.getMapping(), transformOp);
     if (diag.succeeded()) {
       rewriter.setInsertionPoint(foreachThreadOp);
-      IndexType indexType = rewriter.getIndexType();
-      SmallVector<Value> threadOps{
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::x),
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::y),
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::z)};
+      SmallVector<Value> threadOps;
+      threadIdGenerator(rewriter, foreachThreadOp, threadOps);
       diag = rewriteOneForeachThreadToGpuThreads(
           rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
           transformOp, threadMappingAttributes);
@@ -562,10 +559,20 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
       GPUThreadMappingAttr::get(ctx, Threads::DimX),
       GPUThreadMappingAttr::get(ctx, Threads::DimY),
       GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
+  auto threadIdGenerator = [](RewriterBase &rewriter,
+                              scf::ForeachThreadOp foreachThreadOp,
+                              SmallVectorImpl<Value> &threadIds) {
+    IndexType indexType = rewriter.getIndexType();
+    threadIds.assign({rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::x),
+                      rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::y),
+                      rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::z)});
+  };
   diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
-      threadMappingAttributes);
+      rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(),
+      transformOp, threadMappingAttributes);
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,