mirror of
https://github.com/intel/llvm.git
synced 2026-01-21 03:21:40 +08:00
[mlir][linalg] Add support for inlined const to isaFillOpInterface (#144870)
This commit is contained in:
@@ -142,6 +142,9 @@ bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);
|
||||
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);
|
||||
|
||||
/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
|
||||
/// Supports two patterns:
|
||||
/// 1. External: linalg.generic ins(%scalar) outs(%tensor) { yield %scalar }
|
||||
/// 2. Inlined: linalg.generic outs(%tensor) { yield %constant }
|
||||
/// Returns the scalar fill value if true.
|
||||
std::optional<Value> isaFillOpInterface(GenericOp genericOp);
|
||||
|
||||
|
||||
@@ -77,7 +77,37 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FillOpInterface implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
|
||||
/// Detects if a linalg.generic operation represents a fill with an inlined
|
||||
/// constant. If so, returns the constant value. Otherwise, returns
|
||||
/// std::nullopt.
|
||||
static std::optional<Value> isaInlinedFillOp(GenericOp op) {
|
||||
if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
|
||||
op.getNumDpsInputs() != 0)
|
||||
return std::nullopt;
|
||||
|
||||
// Init should not be referenced.
|
||||
if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
|
||||
return std::nullopt;
|
||||
|
||||
Block *body = op.getBody();
|
||||
if (body->getOperations().size() != 1)
|
||||
return std::nullopt;
|
||||
|
||||
auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
|
||||
if (!yieldOp || yieldOp.getNumOperands() != 1)
|
||||
return std::nullopt;
|
||||
|
||||
Value yieldOperand = yieldOp->getOperand(0);
|
||||
if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
|
||||
!yieldOperand.getDefiningOp<complex::ConstantOp>())
|
||||
return std::nullopt;
|
||||
|
||||
return yieldOperand;
|
||||
}
|
||||
|
||||
/// Detects if a linalg.generic operation represents an external scalar input.
|
||||
/// If so, returns the constant value. Otherwise, returns std::nullopt.
|
||||
static std::optional<Value> isaExternalFillOp(GenericOp op) {
|
||||
// Structural.
|
||||
if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
|
||||
!op.isSingleYieldOp())
|
||||
@@ -94,6 +124,12 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
|
||||
return value->get();
|
||||
}
|
||||
|
||||
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
|
||||
if (auto fillVal = isaInlinedFillOp(op))
|
||||
return fillVal;
|
||||
return isaExternalFillOp(op);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BroadcastOpInterface implementation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -267,9 +267,10 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
|
||||
}
|
||||
|
||||
// Fill
|
||||
if (isaFillOpInterface(genericOp)) {
|
||||
if (std::optional<Value> fillValue = isaFillOpInterface(genericOp)) {
|
||||
// Always use the detected fill value, regardless of pattern
|
||||
LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
|
||||
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
|
||||
genericOp, *fillValue, genericOp.getDpsInits()[0]);
|
||||
return namedOp;
|
||||
}
|
||||
|
||||
|
||||
@@ -154,3 +154,28 @@ module attributes {transform.with_named_sequence} {
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0, d1) -> (d0, d1)>
|
||||
func.func @linalg_generic_inlined_constant_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel"]} outs(%arg0 : tensor<7x7xf32>) {
|
||||
^bb0(%out: f32):
|
||||
linalg.yield %cst : f32
|
||||
} -> tensor<7x7xf32>
|
||||
return %0 : tensor<7x7xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: linalg_generic_inlined_constant_fill
|
||||
// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
|
||||
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
|
||||
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user