[mlir][gpu] NFC let user pick the threadID values when distributing foreach_thread
Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D144219
@@ -42,8 +42,11 @@ namespace gpu {
 /// supported. Dynamic block dim sizes are currently not supported.
 DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
+    const SmallVectorImpl<int64_t> &blockDim,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+                      SmallVectorImpl<Value> &)>
+        threadIdGenerator,
+    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
 
 /// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is
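To make the new hook concrete, here is a minimal sketch of a generator that a C++ caller could now supply instead of the built-in x/y/z one. The `warpIdGenerator` name, the hard-coded warp size of 32, and the surrounding includes are illustrative assumptions, not part of this commit:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::gpu;

// Hypothetical generator: distribute the x dimension across warps instead of
// individual threads. threadIds must end up holding one Value per dimension
// the mapping attributes refer to, mirroring the default x/y/z generator.
auto warpIdGenerator = [](RewriterBase &rewriter,
                          scf::ForeachThreadOp foreachThreadOp,
                          SmallVectorImpl<Value> &threadIds) {
  Location loc = foreachThreadOp->getLoc();
  IndexType indexType = rewriter.getIndexType();
  Value tidX = rewriter.create<ThreadIdOp>(loc, indexType, Dimension::x);
  // Assumed warp size of 32; a real caller would derive this from the target.
  Value warpSize = rewriter.create<arith::ConstantIndexOp>(loc, 32);
  Value warpIdX = rewriter.create<arith::DivUIOp>(loc, tidX, warpSize);
  threadIds.assign({warpIdX,
                    rewriter.create<ThreadIdOp>(loc, indexType, Dimension::y),
                    rewriter.create<ThreadIdOp>(loc, indexType, Dimension::z)});
};

Passed in place of the default generator, this would make the distributed loop index along x count warps rather than threads.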
@@ -502,8 +502,11 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
 
 DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
     RewriterBase &rewriter, Operation *target,
-    const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
-    std::optional<TransformOpInterface> transformOp,
+    const SmallVectorImpl<int64_t> &blockDim,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
+                      SmallVectorImpl<Value> &)>
+        threadIdGenerator,
+    bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
     const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
   DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
   target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
@@ -517,14 +520,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
         foreachThreadOp.getMapping(), transformOp);
     if (diag.succeeded()) {
       rewriter.setInsertionPoint(foreachThreadOp);
-      IndexType indexType = rewriter.getIndexType();
-      SmallVector<Value> threadOps{
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::x),
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::y),
-          rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
-                                      Dimension::z)};
+      SmallVector<Value> threadOps;
+      threadIdGenerator(rewriter, foreachThreadOp, threadOps);
       diag = rewriteOneForeachThreadToGpuThreads(
           rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
           transformOp, threadMappingAttributes);
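Note that the insertion point is set on the `scf.foreach_thread` op before the callback runs, so the generator creates its ID values exactly where the old hard-coded `gpu.thread_id` ops used to be built, and it is expected to populate `threadOps` with values matching the thread mapping attributes (x, y, z in the default case).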
@@ -562,10 +559,20 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
       GPUThreadMappingAttr::get(ctx, Threads::DimX),
       GPUThreadMappingAttr::get(ctx, Threads::DimY),
       GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
 
+  auto threadIdGenerator = [](RewriterBase &rewriter,
+                              scf::ForeachThreadOp foreachThreadOp,
+                              SmallVectorImpl<Value> &threadIds) {
+    IndexType indexType = rewriter.getIndexType();
+    threadIds.assign({rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::x),
+                      rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::y),
+                      rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
+                                                  indexType, Dimension::z)});
+  };
   diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
-      rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
-      threadMappingAttributes);
+      rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(),
+      transformOp, threadMappingAttributes);
 
   if (diag.succeeded()) {
     diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,
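As the NFC tag indicates, the default `threadIdGenerator` lambda above recreates exactly the x/y/z `gpu.thread_id` values that `mapNestedForeachToThreadsImpl` previously built internally, so users of the transform op see no behavioral change; the callback simply opens the choice of thread ID values up to C++ callers.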