mirror of
https://github.com/intel/llvm.git
synced 2026-01-14 11:57:39 +08:00
Adds `transform.xegpu.set_gpu_launch_threads` that overrides `gpu.launch` operation threads.
438 lines
17 KiB
C++
438 lines
17 KiB
C++
//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
|
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
|
|
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
|
|
|
|
#include <optional>
|
|
|
|
#include "llvm/Support/DebugLog.h"
|
|
#define DEBUG_TYPE "xegpu-transforms"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::transform;
|
|
|
|
/// Assuming that `ofr` is an index attr or a param of index type
|
|
/// or a transform dialect handle mapped to exactly one op
|
|
/// with one index result, get that value and cast it to int type.
|
|
static DiagnosedSilenceableFailure convertMixedValuesToInt(
|
|
transform::TransformState &state, TransformOpInterface transformOp,
|
|
SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
|
|
for (OpFoldResult ofr : ofrs) {
|
|
// Attribute case.
|
|
if (auto attr = dyn_cast<Attribute>(ofr)) {
|
|
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
|
|
result.push_back(intAttr.getInt());
|
|
continue;
|
|
}
|
|
return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
|
|
}
|
|
|
|
// Transform param case.
|
|
Value transformValue = cast<Value>(ofr);
|
|
if (isa<TransformParamTypeInterface>(transformValue.getType())) {
|
|
ArrayRef<Attribute> params = state.getParams(transformValue);
|
|
if (params.size() != 1)
|
|
return transformOp.emitDefiniteFailure()
|
|
<< "requires exactly one parameter associated";
|
|
result.push_back(
|
|
cast<IntegerAttr>(params.front()).getValue().getSExtValue());
|
|
continue;
|
|
}
|
|
|
|
// Payload value case.
|
|
auto payloadOps = state.getPayloadOps(transformValue);
|
|
if (!llvm::hasSingleElement(payloadOps)) {
|
|
DiagnosedSilenceableFailure diag =
|
|
transformOp.emitSilenceableError()
|
|
<< "handle must be mapped to exactly one payload op";
|
|
diag.attachNote(transformValue.getLoc())
|
|
<< "mapped to " << llvm::range_size(payloadOps) << " payload ops";
|
|
return diag;
|
|
}
|
|
|
|
Operation *op = *payloadOps.begin();
|
|
if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
|
|
DiagnosedSilenceableFailure diag =
|
|
transformOp.emitSilenceableError()
|
|
<< "payload op must have exactly 1 index result";
|
|
diag.attachNote(op->getLoc())
|
|
<< "has " << op->getNumResults() << " results";
|
|
return diag;
|
|
}
|
|
|
|
IntegerAttr intAttr;
|
|
if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
|
|
return transformOp.emitSilenceableError()
|
|
<< "requires param or handle to be the result of a constant like "
|
|
"op";
|
|
|
|
result.push_back(intAttr.getInt());
|
|
}
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
/// Find producer operation of type T for the given value.
|
|
/// It's assumed that producer ops are chained through their first operand.
|
|
/// Producer chain is traced trough loop block arguments (init values).
|
|
template <typename T>
|
|
static std::optional<T> findProducerOfType(Value val) {
|
|
Value currentValue = val;
|
|
if (!currentValue.getDefiningOp()) {
|
|
// Value may be a block argument initialized outside a loop.
|
|
if (val.getNumUses() == 0) {
|
|
LDBG() << "Failed to find producer op, value has no uses.";
|
|
return std::nullopt;
|
|
}
|
|
auto userOp = val.getUsers().begin();
|
|
auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
|
|
if (!parentLoop) {
|
|
LDBG() << "Failed to find producer op, not in a loop.";
|
|
return std::nullopt;
|
|
}
|
|
int64_t iterArgIdx;
|
|
if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
|
|
auto numInductionVars = parentLoop.getLoopInductionVars()->size();
|
|
iterArgIdx = iterArg.getArgNumber() - numInductionVars;
|
|
currentValue = parentLoop.getInits()[iterArgIdx];
|
|
} else {
|
|
LDBG() << "Failed to find producer op, value not in init values.";
|
|
return std::nullopt;
|
|
}
|
|
}
|
|
Operation *producerOp = currentValue.getDefiningOp();
|
|
|
|
if (auto matchingOp = dyn_cast<T>(producerOp))
|
|
return matchingOp;
|
|
|
|
if (producerOp->getNumOperands() == 0)
|
|
return std::nullopt;
|
|
|
|
return findProducerOfType<T>(producerOp->getOperand(0));
|
|
}
|
|
|
|
/// Create a layout attribute from the given parameters.
|
|
static xegpu::LayoutAttr
|
|
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
|
|
ArrayRef<int32_t> sgData,
|
|
std::optional<ArrayRef<int32_t>> instData) {
|
|
return xegpu::LayoutAttr::get(
|
|
ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
|
|
DenseI32ArrayAttr::get(ctx, sgData),
|
|
instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
|
|
/*lane_layout=*/nullptr,
|
|
/*lane_data=*/nullptr,
|
|
/*order=*/nullptr);
|
|
}
|
|
|
|
/// Generate `xegpu::LayoutAttr` from op mixed layout values.
|
|
DiagnosedSilenceableFailure
|
|
getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
|
|
TransformOpInterface transformOp,
|
|
ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
|
|
ArrayRef<::mlir::OpFoldResult> mixedSgData,
|
|
ArrayRef<::mlir::OpFoldResult> mixedInstData,
|
|
xegpu::LayoutAttr &layoutAttr) {
|
|
SmallVector<int32_t> sgLayout, sgData, instData;
|
|
auto status =
|
|
convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
|
|
if (!status.succeeded())
|
|
return status;
|
|
|
|
status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
|
|
if (!status.succeeded())
|
|
return status;
|
|
|
|
status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
|
|
if (!status.succeeded())
|
|
return status;
|
|
auto maybeInstData = instData.empty()
|
|
? std::nullopt
|
|
: std::optional<ArrayRef<int32_t>>(instData);
|
|
|
|
layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
|
|
|
|
return DiagnosedSilenceableFailure::success();
|
|
}
|
|
|
|
/// Replace xegpu.create_nd_desc op with a new one with the given layout.
|
|
static xegpu::CreateNdDescOp
|
|
setDescLayout(transform::TransformRewriter &rewriter,
|
|
xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
|
|
assert(descOp.getMixedOffsets().size() == 0 &&
|
|
"create desc op with offsets is not supported");
|
|
auto oldTensorDesc = descOp.getType();
|
|
auto descType = xegpu::TensorDescType::get(
|
|
oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
|
|
/*array_length=*/oldTensorDesc.getArrayLength(),
|
|
/*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
|
|
/*memory_space=*/oldTensorDesc.getMemorySpace(),
|
|
/*layout=*/layout);
|
|
|
|
rewriter.setInsertionPointAfter(descOp);
|
|
auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
|
|
descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
|
|
descOp.getMixedStrides());
|
|
return newDescOp;
|
|
}
|
|
|
|
/// Maps the result handle to the `xegpu.create_nd_desc` op that produces the
/// single target payload value.
///
/// Emits a definite failure unless the target handle maps to exactly one
/// value, and a silenceable failure when no descriptor op is found on the
/// producer chain.
DiagnosedSilenceableFailure
transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
                            transform::TransformResults &results,
                            transform::TransformState &state) {
  auto payloadValues = state.getPayloadValues(getTarget());
  if (!llvm::hasSingleElement(payloadValues))
    return emitDefiniteFailure()
           << "requires exactly one target value handle (got "
           << llvm::range_size(payloadValues) << ")";

  Value targetValue = *payloadValues.begin();
  std::optional<xegpu::CreateNdDescOp> descOp =
      findProducerOfType<xegpu::CreateNdDescOp>(targetValue);
  if (!descOp.has_value())
    return emitSilenceableFailure(getLoc())
           << "Could not find a matching descriptor op when walking the "
              "producer chain of the first operand.";

  results.set(llvm::cast<OpResult>(getResult()), {*descOp});
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Builder taking mixed (static or dynamic) layout values.
///
/// Each mixed list is split into static integers and dynamic values, which
/// are forwarded to the generated builder. The result type is the type of
/// `target`.
void transform::SetDescLayoutOp::build(OpBuilder &builder,
                                       OperationState &result, Value target,
                                       ArrayRef<OpFoldResult> mixedSgLayout,
                                       ArrayRef<OpFoldResult> mixedSgData,
                                       ArrayRef<OpFoldResult> mixedInstData) {
  // Subgroup layout.
  SmallVector<int64_t> sgLayoutStatic;
  SmallVector<Value> sgLayoutDynamic;
  dispatchIndexOpFoldResults(mixedSgLayout, sgLayoutDynamic, sgLayoutStatic);

  // Subgroup data.
  SmallVector<int64_t> sgDataStatic;
  SmallVector<Value> sgDataDynamic;
  dispatchIndexOpFoldResults(mixedSgData, sgDataDynamic, sgDataStatic);

  // Instruction data.
  SmallVector<int64_t> instDataStatic;
  SmallVector<Value> instDataDynamic;
  dispatchIndexOpFoldResults(mixedInstData, instDataDynamic, instDataStatic);

  build(builder, result, target.getType(),
        /*target=*/target,
        /*sg_layout=*/sgLayoutDynamic,
        /*sg_data=*/sgDataDynamic,
        /*inst_data=*/instDataDynamic,
        /*static_sg_layout=*/sgLayoutStatic,
        /*static_sg_data=*/sgDataStatic,
        /*static_inst_data=*/instDataStatic);
}
|
|
|
|
/// Replaces the target `xegpu.create_nd_desc` op with one whose descriptor
/// type carries the layout described by the op's mixed layout operands, and
/// maps the result handle to the new op.
///
/// Emits a definite failure unless the target handle maps to exactly one op,
/// forwards any failure from the layout operand conversion, and emits a
/// silenceable failure when the target is not a `xegpu.create_nd_desc` op.
DiagnosedSilenceableFailure
transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
                                 << llvm::range_size(payloadOps) << ")";
  Operation *payloadOp = *payloadOps.begin();

  // Resolve the layout first; conversion errors take precedence over the
  // target op kind check below.
  xegpu::LayoutAttr layout = nullptr;
  DiagnosedSilenceableFailure layoutStatus = getLayoutAttrFromOperands(
      getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
      getMixedInstData(), layout);
  if (!layoutStatus.succeeded())
    return layoutStatus;

  // For now only create_nd_desc op is supported.
  if (auto descOp = dyn_cast<xegpu::CreateNdDescOp>(payloadOp)) {
    // Set layout attr in desc op's return type. Replaces old desc op.
    auto updatedDescOp = setDescLayout(rewriter, descOp, layout);

    // Map result handles.
    results.set(cast<OpResult>(getTransformed()),
                {updatedDescOp.getOperation()});
    return DiagnosedSilenceableFailure::success();
  }

  auto diag = emitSilenceableFailure(getLoc())
              << "Expected a xegpu.create_nd_desc op, but got: "
              << payloadOp->getName();
  diag.attachNote(payloadOp->getLoc()) << "target op";
  return diag;
}
|
|
|
|
/// Declares the side effects of the op on transform handles and payload IR.
void transform::SetDescLayoutOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle is consumed; the op produces a fresh result handle.
  consumesHandle(getTargetMutable(), effects);

  // Layout operands are only read.
  onlyReadsHandle(getSgLayoutMutable(), effects);
  onlyReadsHandle(getSgDataMutable(), effects);
  onlyReadsHandle(getInstDataMutable(), effects);

  producesHandle(getOperation()->getOpResults(), effects);

  // The payload IR is modified.
  modifiesPayload(effects);
}
|
|
|
|
/// Builder taking mixed (static or dynamic) layout values.
///
/// Each mixed list is split into static integers and dynamic values, which
/// are forwarded together with `index` and `result` to the generated builder.
/// The result type is the type of `target`.
void transform::SetOpLayoutAttrOp::build(
    OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
    ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
    ArrayRef<OpFoldResult> mixedInstData, bool result) {
  // Subgroup layout.
  SmallVector<int64_t> sgLayoutStatic;
  SmallVector<Value> sgLayoutDynamic;
  dispatchIndexOpFoldResults(mixedSgLayout, sgLayoutDynamic, sgLayoutStatic);

  // Subgroup data.
  SmallVector<int64_t> sgDataStatic;
  SmallVector<Value> sgDataDynamic;
  dispatchIndexOpFoldResults(mixedSgData, sgDataDynamic, sgDataStatic);

  // Instruction data.
  SmallVector<int64_t> instDataStatic;
  SmallVector<Value> instDataDynamic;
  dispatchIndexOpFoldResults(mixedInstData, instDataDynamic, instDataStatic);

  build(builder, ostate, target.getType(),
        /*target=*/target,
        /*index=*/index,
        /*sg_layout=*/sgLayoutDynamic,
        /*sg_data=*/sgDataDynamic,
        /*inst_data=*/instDataDynamic,
        /*static_sg_layout=*/sgLayoutStatic,
        /*static_sg_data=*/sgDataStatic,
        /*static_inst_data=*/instDataStatic,
        /*result=*/result);
}
|
|
|
|
/// Attaches a layout attribute to one result or one operand of the target op.
///
/// Whether a result or an operand is annotated is selected by `getResult()`;
/// `getIndex()` selects which one. The layout is built from the op's mixed
/// layout operands.
///
/// Emits a definite failure unless the target handle maps to exactly one op,
/// a silenceable failure when the index is negative or out of range, and
/// forwards any failure from the layout operand conversion.
DiagnosedSilenceableFailure
transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
                                    transform::TransformResults &results,
                                    transform::TransformState &state) {
  auto targetOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(targetOps)) {
    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
                                 << llvm::range_size(targetOps) << ")";
  }
  Operation *target = *targetOps.begin();

  const bool resultTarget = getResult();
  const int64_t index = getIndex();

  // A negative index would pass the signed upper-bound checks below and then
  // index out of bounds, so reject it explicitly.
  if (index < 0) {
    return emitSilenceableFailure(getLoc()) << "Index must be non-negative";
  }
  if (resultTarget &&
      index >= static_cast<int64_t>(target->getNumResults())) {
    return emitSilenceableFailure(getLoc())
           << "Index exceeds the number of op results";
  }
  if (!resultTarget &&
      index >= static_cast<int64_t>(target->getNumOperands())) {
    return emitSilenceableFailure(getLoc())
           << "Index exceeds the number of op operands";
  }

  xegpu::LayoutAttr layoutAttr = nullptr;
  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
                                          getMixedSgLayout(), getMixedSgData(),
                                          getMixedInstData(), layoutAttr);
  if (!status.succeeded())
    return status;

  // Set layout attribute for the op result or operand
  if (resultTarget)
    xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
  else
    xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Declares the side effects of the op on transform handles and payload IR.
void transform::SetOpLayoutAttrOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // The target handle and all layout operands are only read.
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getSgLayoutMutable(), effects);
  onlyReadsHandle(getSgDataMutable(), effects);
  onlyReadsHandle(getInstDataMutable(), effects);

  // The payload IR is modified.
  modifiesPayload(effects);
}
|
|
|
|
/// Builder taking mixed (static or dynamic) thread counts.
///
/// The mixed list is split into static integers and dynamic values, which
/// are forwarded to the generated builder. The result type is the type of
/// `target`.
void transform::SetGPULaunchThreadsOp::build(
    OpBuilder &builder, OperationState &ostate, Value target,
    ArrayRef<OpFoldResult> mixedThreads) {
  SmallVector<Value> threadValues;
  SmallVector<int64_t> threadConstants;
  dispatchIndexOpFoldResults(mixedThreads, threadValues, threadConstants);

  build(builder, ostate, target.getType(),
        /*target=*/target,
        /*threads=*/threadValues,
        /*static_threads=*/threadConstants);
}
|
|
|
|
/// Overrides the block size (threads) operands of the target `gpu.launch` op
/// with index constants created right before the launch op.
///
/// Emits a definite failure unless the target handle maps to exactly one op,
/// a silenceable failure when that op is not a `gpu.launch` or when the
/// threads argument does not consist of three values, and forwards any
/// failure from the threads operand conversion.
DiagnosedSilenceableFailure
transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
                                 << llvm::range_size(payloadOps) << ")";
  Operation *payloadOp = *payloadOps.begin();

  auto launchOp = dyn_cast<gpu::LaunchOp>(payloadOp);
  if (!launchOp) {
    auto diag = emitSilenceableFailure(getLoc())
                << "Expected a gpu.launch op, but got: "
                << payloadOp->getName();
    diag.attachNote(payloadOp->getLoc()) << "target op";
    return diag;
  }

  SmallVector<int32_t> threadCounts;
  DiagnosedSilenceableFailure conversionStatus = convertMixedValuesToInt(
      state, (*this), threadCounts, getMixedThreads());
  if (!conversionStatus.succeeded())
    return conversionStatus;

  if (threadCounts.size() != 3)
    return emitSilenceableFailure(getLoc())
           << "Expected threads argument to consist of three values (got "
           << threadCounts.size() << ")";

  // New constants are materialized immediately before the launch op.
  rewriter.setInsertionPoint(launchOp);
  auto makeIndexConstant = [&](int value) {
    return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
  };

  // Replace threads in-place, one dimension at a time (x, y, z).
  launchOp.getBlockSizeXMutable().assign(makeIndexConstant(threadCounts[0]));
  launchOp.getBlockSizeYMutable().assign(makeIndexConstant(threadCounts[1]));
  launchOp.getBlockSizeZMutable().assign(makeIndexConstant(threadCounts[2]));

  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Declares the side effects of the op on transform handles and payload IR.
void transform::SetGPULaunchThreadsOp::getEffects(
    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  // Both the target handle and the threads operands are only read.
  onlyReadsHandle(getTargetMutable(), effects);
  onlyReadsHandle(getThreadsMutable(), effects);

  // The payload IR is modified.
  modifiesPayload(effects);
}
|
|
|
|
namespace {
/// Transform dialect extension that registers the XeGPU transform ops and
/// declares the dialects whose ops these transforms may generate.
class XeGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          XeGPUTransformDialectExtension> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)

  using Base::Base;

  /// Declares generated dialects and registers the ops listed in the
  /// generated op list.
  void init() {
    declareGeneratedDialect<scf::SCFDialect>();
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<xegpu::XeGPUDialect>();

    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
|
|
|
|
void mlir::xegpu::registerTransformDialectExtension(DialectRegistry ®istry) {
|
|
registry.addExtensions<XeGPUTransformDialectExtension>();
|
|
}
|