[mlir][linalg] Add support for inlined const to isaFillOpInterface (#144870)

This commit is contained in:
Shay Kleiman
2025-06-23 22:53:41 +03:00
committed by GitHub
parent 653d0d0073
commit 5f74d9bb62
4 changed files with 68 additions and 3 deletions

View File

@@ -142,6 +142,9 @@ bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
/// Supports two patterns:
/// 1. External: linalg.generic ins(%scalar) outs(%tensor) { yield %scalar }
/// 2. Inlined: linalg.generic outs(%tensor) { yield %constant }
/// Returns the scalar fill value if true.
std::optional<Value> isaFillOpInterface(GenericOp genericOp);

View File

@@ -77,7 +77,37 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
/// Detects if a linalg.generic operation represents a fill with an inlined
/// constant. If so, returns the constant value. Otherwise, returns
/// std::nullopt.
static std::optional<Value> isaInlinedFillOp(GenericOp op) {
if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
op.getNumDpsInputs() != 0)
return std::nullopt;
// Init should not be referenced.
if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
return std::nullopt;
Block *body = op.getBody();
if (body->getOperations().size() != 1)
return std::nullopt;
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
if (!yieldOp || yieldOp.getNumOperands() != 1)
return std::nullopt;
Value yieldOperand = yieldOp->getOperand(0);
if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
!yieldOperand.getDefiningOp<complex::ConstantOp>())
return std::nullopt;
return yieldOperand;
}
/// Detects if a linalg.generic operation represents an external scalar input.
/// If so, returns the constant value. Otherwise, returns std::nullopt.
static std::optional<Value> isaExternalFillOp(GenericOp op) {
// Structural.
if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
!op.isSingleYieldOp())
@@ -94,6 +124,12 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
return value->get();
}
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
if (auto fillVal = isaInlinedFillOp(op))
return fillVal;
return isaExternalFillOp(op);
}
//===----------------------------------------------------------------------===//
// BroadcastOpInterface implementation
//===----------------------------------------------------------------------===//

View File

@@ -267,9 +267,10 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
}
// Fill
if (isaFillOpInterface(genericOp)) {
if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
// Always use the detected fill value, regardless of pattern
LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
genericOp, *fillValue, genericOp.getDpsInits()[0]);
return namedOp;
}

View File

@@ -154,3 +154,28 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
// -----
#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
^bb0(%out: f32):
linalg.yield %cst : f32
} -> tensor<7x7xf32>
return %0 : tensor<7x7xf32>
}
// CHECK-LABEL: linalg_generic_inlined_constant_fill
// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}