[mlir][Linalg] Adding a greedy packing transform dialect op.

This PR adds a `pack_greedily` transform operation that infers the packing for gemm
subcomputations embedded within any LinalgOp and packs accordingly.
A normalization step guarantees that we get the innermost op dimensions in one of the `6`
possible `(m, n, k)` orders, specified as a parameter, from which we can emit all
packed forms.

The current implementation takes an arbitrary LinalgOp and tries to pack it along
the specified dimensions, with the specified sizes and inner-dimension permutation.

This achieves a new level of normalization and generalization for any `n-D`
LinalgOp that contains a gemm embedded within it:
we will always see a predictable packed form for any of these ops.
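As an illustration (a condensed sketch of the first case in the new test file), the op
can be driven as follows:

  transform.sequence failures(propagate) {
  ^bb1(%module_op: !pdl.operation):
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op
      : (!pdl.operation) -> !transform.op<"linalg.matmul">
    transform.structured.pack_greedily %matmul
        gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
      : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">
  }

This rewrites a plain `linalg.matmul` into a packed `linalg.generic` whose iterators are
["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"].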

Differential Revision: https://reviews.llvm.org/D142661
Author: Nicolas Vasilache
Date: 2023-01-26 12:30:54 -08:00
Parent: c10615e4a9
Commit: 55cf0de35e
3 changed files with 584 additions and 14 deletions


@@ -370,7 +370,7 @@ def MultiTileSizesOp : Op<Transform_Dialect, "structured.multitile_sizes",
def PackOp : Op<Transform_Dialect, "structured.pack", [
DeclareOpInterfaceMethods<TransformOpInterface>,
-DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,]> {
+DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
Pack a LinalgOp by applying a data tiling transformation on the op and
packing the operands according to the `packed_sizes` specification.
@@ -453,6 +453,84 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
}];
}
//===----------------------------------------------------------------------===//
// PackGreedilyOp
//===----------------------------------------------------------------------===//
def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let description = [{
Target a Linalg op and rewrite it into packed LinalgOp form by trying to
infer whether a known suboperation is embedded in the target op.
Different packing strategies are applied in order; when one applies
successfully, the transform returns:
1. Gemm packing: Try to infer a gemm operation embedded in the target op.
Specifically, this looks for 2 parallel dimensions that participate in
an outer-product and 1 reduction dimension.
These dimensions are referred to as (m, n, k) to match canonical gemm
terminology.
The packed sizes for (m, n, k) are specified by `gemm_packed_sizes`.
The ordering of the packed dimensions (mm, nn, kk) is specified by the
`gemm_inner_dims_order` attribute.
Packing occurs as follows:
1. Find the dimensions to pack according to the strategy.
2. The target is converted to linalg.generic form.
3. An interchange transform is applied to isolate the dimensions to pack as
the most minor indexing dimensions of the linalg.generic. The most minor
dimensions are themselves ordered according to `inner_dims_order`.
4. Packing is performed according to `packed_sizes`, following `inner_dims_order`.
By normalizing the most minor dimensions to `inner_dims_order`, the transform
guarantees that packing immediately generates inner dimensions in a desirable
layout.
Outer dimension layout permutations are not controlled by this transform op
at the moment and can be obtained by composing with the pack_transpose
transformation.
#### Return modes
This operation ignores non-Linalg ops and drops them from the return.
It returns the list of packed Linalg ops or the original op when all available
packing strategies failed to apply.
}];
// TODO: Transform_ConcreteOpType<linalg::LinalgOp> needs interface.
let arguments = (ins TransformHandleTypeInterface:$target,
Variadic<PDL_Operation>:$gemm_packed_sizes,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">
:$static_gemm_packed_sizes,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">
:$gemm_inner_dims_order);
let results = (outs Transform_ConcreteOpType<"linalg.generic">:$packed_op);
let builders = [
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedGemmPackedSizes,
CArg<"ArrayRef<int64_t>", "{}">:$gemmDimsInnerDimsOrder)>
];
let assemblyFormat = [{
$target
oilist(
`gemm_packed_sizes` `=` custom<DynamicIndexList>($gemm_packed_sizes,
$static_gemm_packed_sizes)
`gemm_inner_dims_order` `=` $gemm_inner_dims_order
)
attr-dict
`:` functional-type($target, results)
}];
let hasVerifier = 1;
let extraClassDeclaration = [{
/// Returns the list of tile sizes, which may be static (Attribute) or
/// dynamic (Value).
SmallVector<OpFoldResult> getMixedGemmPackedSizes();
}];
}
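Since `gemm_packed_sizes` is declared as `Variadic<PDL_Operation>` next to
`static_gemm_packed_sizes`, the `DynamicIndexList` syntax should accept a mix of static
and dynamic entries. A hypothetical sketch, assuming a `%sz` handle produced by some
other transform op (not taken from the tests):

  transform.structured.pack_greedily %matmul
      gemm_packed_sizes = [8, %sz, 32] gemm_inner_dims_order = [1, 2, 0]
    : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">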
//===----------------------------------------------------------------------===//
// PackTransposeOp
//===----------------------------------------------------------------------===//


@@ -30,9 +30,11 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Debug.h"
@@ -134,7 +136,7 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPDLOperations(
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
-transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
+transform::DecomposeOp::applyToOne(LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
#define DOWNSCALE(trans) \
@@ -642,7 +644,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
-transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
+transform::GeneralizeOp::applyToOne(LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Exit early if no transformation is needed.
@@ -663,7 +665,7 @@ transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
-transform::InterchangeOp::applyToOne(linalg::GenericOp target,
+transform::InterchangeOp::applyToOne(GenericOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
@@ -730,7 +732,7 @@ transform::MatchOp::apply(transform::TransformResults &results,
if (getInterface().has_value()) {
auto iface = getInterface().value();
if (iface == transform::MatchInterfaceEnum::LinalgOp &&
-!isa<linalg::LinalgOp>(op))
+!isa<LinalgOp>(op))
return;
if (iface == transform::MatchInterfaceEnum::TilingInterface &&
isa<TilingInterface>(op))
@@ -885,7 +887,7 @@ void transform::PackOp::build(OpBuilder &builder, OperationState &result,
// attributes for multiple variadic operands. In the absence of this, horrible
// bugs ensue.
Type linalgOpHType = transform::OperationType::get(
-builder.getContext(), linalg::GenericOp::getOperationName());
+builder.getContext(), GenericOp::getOperationName());
build(builder, result,
/*resultType=*/linalgOpHType,
/*target=*/target,
@@ -908,7 +910,7 @@ transform::PackOp::apply(transform::TransformResults &transformResults,
return DiagnosedSilenceableFailure::success();
}
// Fail on multi-op handles.
-auto linalgOp = dyn_cast<linalg::LinalgOp>(targetOps.front());
+auto linalgOp = dyn_cast<LinalgOp>(targetOps.front());
if (targetOps.size() != 1 || !linalgOp) {
return emitSilenceableError()
<< "requires target to map to exactly 1 LinalgOp (got "
@@ -946,6 +948,268 @@ void transform::PackOp::getEffects(
transform::modifiesPayload(effects);
}
//===---------------------------------------------------------------------===//
// PackGreedilyOp.
//===---------------------------------------------------------------------===//
LogicalResult transform::PackGreedilyOp::verify() {
if (!isPermutationVector(getGemmInnerDimsOrder())) {
return emitOpError() << getGemmInnerDimsOrderAttrName()
<< " is not a valid permutation";
}
// TODO: relax to allow empty once we have another strategy than just gemm.
if (getGemmInnerDimsOrder().size() != 3 ||
getMixedGemmPackedSizes().size() != 3) {
return emitOpError() << " needs 3 entries for gemm_packed_sizes and "
<< getGemmInnerDimsOrderAttrName()
<< " order for the gemm strategy";
}
return success();
}
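For instance, a sketch of an instance the verifier above rejects (assuming a `%matmul`
handle as in the tests), since the gemm strategy currently requires exactly 3 entries:

  // error: needs 3 entries for gemm_packed_sizes ... for the gemm strategy
  transform.structured.pack_greedily %matmul
      gemm_packed_sizes = [8, 16] gemm_inner_dims_order = [1, 2, 0]
    : (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">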
namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
/// Return the set of AffineDimExpr dimensions of iterator type `iter` that
/// appear exactly once in the indexing map of `opOperand`.
static DenseSet<int64_t>
findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
utils::IteratorType iter) {
DenseSet<int64_t> res;
assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
for (AffineExpr e : indexingMap.getResults()) {
if (auto d = e.dyn_cast<AffineDimExpr>()) {
if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
return e.isFunctionOfDim(d.getPosition());
}) == 1)
res.insert(d.getPosition());
}
}
return res;
}
struct GemmDimsForPacking {
int64_t mPos, nPos, kPos;
};
/// Greedily look for 2 parallel (m and n) and 1 reduction (k) dimension that
/// form a gemm. These dimensions are such that:
/// 1. The m dimension is involved in an outer-product along LHS
/// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
/// 2. The n dimension is involved in an outer-product along RHS
/// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
/// 3. The k dimension appears as a permutation on LHS and RHS.
/// 4. m, n and k appear only once in any given indexing.
///
/// This allows detecting that some gemm is embedded within `linalgOp`.
///
/// When multiple possibilities for selecting m, n and k appear, we just pick
/// an arbitrary one (i.e. the first in a DenseSet).
// TODO: Better heuristic (e.g pick dims based on packing-based metric).
static FailureOr<GemmDimsForPacking> getGemmDims(LinalgOp linalgOp) {
assert(linalgOp.getNumDpsInits() == 1 && "wrong number of dps inits");
assert(linalgOp.getNumDpsInputs() == 2 && "wrong number of dps inputs");
DenseSet<int64_t> a = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(0), par);
DenseSet<int64_t> b = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(1), par);
DenseSet<int64_t> c = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInitOperand(0), par);
// A & C - B are the iterators involved in an outer-product along A (the LHS).
DenseSet<int64_t> ac = a;
llvm::set_intersect(ac, c);
llvm::set_subtract(ac, b);
// B & C - A are the iterators involved in an outer-product along B (the RHS).
DenseSet<int64_t> bc = b;
llvm::set_intersect(bc, c);
llvm::set_subtract(bc, a);
// Note: if we ever need them, A & B & C would be "batch" dimensions.
// The dimensions that appear as reductions in both A and B are the gemm
// reduction dimensions.
DenseSet<int64_t> ra = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(0), red);
DenseSet<int64_t> rb = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(1), red);
llvm::set_intersect(ra, rb);
if (ac.empty() || bc.empty() || ra.empty())
return failure();
// Pick the first one in each set.
// TODO: Better heuristic (e.g pick dims based on packing-based metric).
return GemmDimsForPacking{*ac.begin(), *bc.begin(), *ra.begin()};
}
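As a hand-worked illustration of this detection (annotation only, not part of the
patch), take a plain matmul with iterators [parallel, parallel, reduction] and
d0 = m, d1 = n, d2 = k:

  // LHS: (d0, d1, d2) -> (d0, d2)   RHS: (d0, d1, d2) -> (d2, d1)
  // RES: (d0, d1, d2) -> (d0, d1)
  // a (parallel dims used exactly once on LHS) = {0}
  // b (parallel dims used exactly once on RHS) = {1}
  // c (parallel dims used exactly once on RES) = {0, 1}
  // ac = (a & c) - b = {0} -> m,  bc = (b & c) - a = {1} -> n
  // ra & rb = {2} -> k, yielding GemmDimsForPacking{0, 1, 2}.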
/// Return a permutation vector of size permSize that would result in moving
/// positions into desiredPositions.
///
/// For example, permSize == 5, positions = {2, 4}, desiredPositions = {1, 0}
/// would result in a {4, 2, 0, 1, 3} permutation vector.
static SmallVector<int64_t>
computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
ArrayRef<int64_t> desiredPositions) {
SmallVector<int64_t> res(permSize, -1);
DenseSet<int64_t> seen;
for (auto [pos, desiredPos] : llvm::zip(positions, desiredPositions)) {
res[desiredPos] = pos;
seen.insert(pos);
}
int64_t nextPos = 0;
for (int64_t &entry : res) {
if (entry != -1)
continue;
while (seen.contains(nextPos))
++nextPos;
entry = nextPos;
++nextPos;
}
return res;
}
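Tracing the doc-comment example through the code (annotation only):

  // permSize = 5, positions = {2, 4}, desiredPositions = {1, 0}
  // Seeding:  res[1] = 2, res[0] = 4  ->  res = {4, 2, -1, -1, -1}
  // Filling:  the remaining slots take the smallest unused values 0, 1, 3
  // Result:   res = {4, 2, 0, 1, 3}, as in the example above.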
/// Pack a LinalgOp by greedily inferring gemm dimensions (m, n, k)
/// where m and n are proper parallel dimensions and k is a proper reduction
/// dimension.
/// Packing occurs by rewriting the op as a linalg.generic and calling
/// linalg::pack by `mnkPackedSizes`.
/// The order of the packed dimensions is customizable: the `mnkOrder` is a
/// permutation of {0, 1, 2} to reorder {m, n, k} into one of the 6 possible
/// forms.
/// The outer dimensions of the operands are not permuted at this time, this is
/// left for future work.
static FailureOr<LinalgOp>
packGemmGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
ArrayRef<OpFoldResult> mnkPackedSizes,
ArrayRef<int64_t> mnkOrder) {
assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
assert(isPermutationVector(mnkOrder) && "expected a permutation");
int64_t numLoops = linalgOp.getNumLoops();
if (numLoops <= 2) {
return rewriter.notifyMatchFailure(linalgOp,
"need 3+ loops to find a gemm to pack");
}
// Locally adjust the desired iterator position of mnk and packing sizes.
int64_t numPackedDims = mnkPackedSizes.size();
SmallVector<int64_t> mmnnkkPos(numPackedDims);
for (int64_t i = 0, e = numPackedDims; i < e; ++i)
mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
SmallVector<OpFoldResult> packedSizes(mnkPackedSizes.size());
for (int64_t i = 0, e = numPackedDims; i < e; ++i)
packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
// 1. Infer dims that are important for gemm.
FailureOr<GemmDimsForPacking> res = getGemmDims(linalgOp);
if (failed(res)) {
return rewriter.notifyMatchFailure(linalgOp,
"couldn't infer gemm iterators");
}
// 2. Normalize linalgOp to a kmn-matmul-like form with [red, par, par] as the
// most minor iterators. If we wanted a different normalization order, this is
// where it would have to start.
int64_t mPos = res->mPos, nPos = res->nPos, kPos = res->kPos;
LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
DBGS() << "Start packing generic op greedily with (m@" << mPos
<< ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
<< "\n";);
// 2.a. Rewrite as a generic.
auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
if (!genericOp) {
FailureOr<GenericOp> generalizeResult =
generalizeNamedOp(rewriter, linalgOp);
assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
genericOp = *generalizeResult;
}
// 2.b. Interchange to move the dimensions (k, m, n) to be the most-minor
// iterators. Note that this only normalizes the iteration order and does not
// change the indexing maps of any operand.
SmallVector<int64_t> permutation =
computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
// Signed to unsigned conversion for the interchange API.
SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
FailureOr<GenericOp> interchangeResult =
interchangeGenericOp(rewriter, genericOp, unsignedPerm);
assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
genericOp = *interchangeResult;
LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);
// At this point, the op iterators are normalized to {leading, k, m, n}.
// The layouts induced by packing will always be:
// - LHS{leading_lhs, kk, mm}
// - RHS{leading_rhs, kk, nn}
// - RES{leading_res, mm, nn}
// If we wanted to change the packed order, we would reorder (k, m, n) to
// something else above.
//
// Additional permutations of the outer dims of the operands (i.e.
// leading_lhs, leading_rhs and leading_res) could follow by computing the
// desired outerPerm for each operand.
// This is left for future work.
// Add leading zeros to match numLoops.
SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
rewriter.getIndexAttr(0));
llvm::append_range(adjustedPackedSizes, packedSizes);
// TODO: If we wanted to give the genericOp a name after packing, after
// calling `pack` would be a good time.
return linalg::pack(rewriter, genericOp, adjustedPackedSizes);
}
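Concretely, in the first test below (packed sizes (8, 16, 32), order [1, 2, 0]), the
normalized dims are (d0, d1, d2) = (k, m, n) and (d3, d4, d5) = (kk, mm, nn), and the
packed indexing maps come out as:

  // LHS: (d0, ..., d5) -> (d1, d0, d3, d4)   i.e. {m, k, kk, mm}
  // RHS: (d0, ..., d5) -> (d0, d2, d3, d5)   i.e. {k, n, kk, nn}
  // RES: (d0, ..., d5) -> (d1, d2, d4, d5)   i.e. {m, n, mm, nn}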
DiagnosedSilenceableFailure
PackGreedilyOp::apply(transform::TransformResults &transformResults,
transform::TransformState &state) {
ArrayRef<Operation *> targetOps = state.getPayloadOps(getTarget());
SmallVector<Operation *> results;
IRRewriter rewriter(getContext());
for (Operation *op : targetOps) {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
continue;
// linalgOp will be replaced and the insertion point may be invalidated if
// we set it before the op -> set it after instead.
rewriter.setInsertionPointAfter(linalgOp);
// Failing to pack greedily is perfectly fine.
// In the future we will want to order packings according to some metric.
FailureOr<LinalgOp> gemm = packGemmGreedily(
/*rewriter=*/rewriter,
/*linalgOp=*/linalgOp,
/*mnkPackedSizes=*/getMixedGemmPackedSizes(),
/*mnkOrder=*/getGemmInnerDimsOrder());
if (succeeded(gemm)) {
results.push_back(*gemm);
continue;
}
results.push_back(linalgOp);
}
transformResults.set(getPackedOp().cast<OpResult>(), results);
return DiagnosedSilenceableFailure::success();
}
SmallVector<OpFoldResult> PackGreedilyOp::getMixedGemmPackedSizes() {
Builder b(getContext());
return getMixedValues(getStaticGemmPackedSizes(), getGemmPackedSizes(), b);
}
void transform::PackGreedilyOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::consumesHandle(getTarget(), effects);
transform::onlyReadsHandle(getGemmPackedSizes(), effects);
transform::producesHandle(getPackedOp(), effects);
transform::modifiesPayload(effects);
}
//===---------------------------------------------------------------------===//
// PackTransposeOp
//===---------------------------------------------------------------------===//
@@ -1030,7 +1294,7 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
return emitSilenceableError() << "requires target to map to a "
"tensor.pack or tensor.unpack";
}
-LinalgOp linalgOpTarget = dyn_cast<linalg::LinalgOp>(linalgOps.front());
+LinalgOp linalgOpTarget = dyn_cast<LinalgOp>(linalgOps.front());
if (!linalgOpTarget)
return emitSilenceableError() << "requires a LinalgOp target";
@@ -1102,7 +1366,7 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults,
//===---------------------------------------------------------------------===//
DiagnosedSilenceableFailure
-transform::PadOp::applyToOne(linalg::LinalgOp target,
+transform::PadOp::applyToOne(LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
// Convert the integer packing flags to booleans.
@@ -1214,7 +1478,7 @@ LogicalResult transform::PadOp::verify() {
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
-transform::PromoteOp::applyToOne(linalg::LinalgOp target,
+transform::PromoteOp::applyToOne(LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
LinalgPromotionOptions promotionOptions;
@@ -1308,7 +1572,7 @@ LogicalResult transform::ReplaceOp::verify() {
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
-transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
+transform::ScalarizeOp::applyToOne(LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
scf::SCFTilingOptions tilingOptions;
@@ -1560,7 +1824,7 @@ void transform::SplitReductionOp::build(
}
DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne(
-linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
ControlSplitReductionFn splitFn = [&](LinalgOp) {
return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
@@ -1605,7 +1869,7 @@ void transform::TileReductionUsingScfOp::build(
}
DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
-linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);
@@ -1649,7 +1913,7 @@ void transform::TileReductionUsingForeachThreadOp::build(
DiagnosedSilenceableFailure
transform::TileReductionUsingForeachThreadOp::applyToOne(
-linalg::LinalgOp target, transform::ApplyToEachResultList &results,
+LinalgOp target, transform::ApplyToEachResultList &results,
transform::TransformState &state) {
TrivialPatternRewriter rewriter(getContext());
rewriter.setInsertionPoint(target);


@@ -0,0 +1,228 @@
// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
!A_mk = tensor<1023x255xf32>
!B_kn = tensor<255x127xf32>
!C_mn = tensor<1023x127xf32>
// Normalized dims are: ( k, m, n)(kk, mm, nn)
// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
// CHECK-LABEL: @matmul_mk_kn_mn(
func.func @matmul_mk_kn_mn(%A : !A_mk, %B : !B_kn, %C : !C_mn) -> !C_mn {
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
// CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<128x8x8x16xf32>)
%0 = linalg.matmul ins(%A, %B : !A_mk, !B_kn) outs(%C : !C_mn) -> !C_mn
return %0 : !C_mn
}
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%matmul = transform.structured.match ops{["linalg.matmul"]} in %module_op
: (!pdl.operation) -> !transform.op<"linalg.matmul">
transform.structured.pack_greedily %matmul
gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.matmul">) -> !transform.op<"linalg.generic">
}
// -----
!A_mk = tensor<1023x255xf32>
!B_nk = tensor<127x255xf32>
!C_nm = tensor<127x1023xf32>
#mkn_accesses = [
affine_map<(m, n, k) -> (m, k)>,
affine_map<(m, n, k) -> (n, k)>,
affine_map<(m, n, k) -> (n, m)>
]
#mkn_trait = {
indexing_maps = #mkn_accesses,
iterator_types = ["parallel", "parallel", "reduction"]
}
// Normalized dims are: ( k, m, n)(kk, mm, nn)
// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
// CHECK-LABEL: @matmul_mk_nk_nm(
func.func @matmul_mk_nk_nm(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm {
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
// CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<8x128x8x16xf32>)
%0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) {
^bb0(%a: f32, %b: f32, %c: f32):
%d = arith.mulf %a, %b : f32
%e = arith.addf %c, %d : f32
linalg.yield %e : f32
} -> !C_nm
return %0 : !C_nm
}
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
transform.structured.pack_greedily %generic
gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
// -----
!A_mk = tensor<1023x255xf32>
!B_nk = tensor<127x255xf32>
!C_nm = tensor<127x1023xf32>
#mkn_accesses = [
affine_map<(k, m, n) -> (m, k)>,
affine_map<(k, m, n) -> (n, k)>,
affine_map<(k, m, n) -> (n, m)>
]
#mkn_trait = {
indexing_maps = #mkn_accesses,
iterator_types = ["reduction", "parallel", "parallel"]
}
// Normalized dims are: ( k, m, n)(kk, mm, nn)
// CHECK-DAG: #[[$mk_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
// CHECK-DAG: #[[$kn_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
// CHECK-DAG: #[[$mn_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d4, d5)>
// CHECK-LABEL: @matmul_mk_nk_nm_transposed(
func.func @matmul_mk_nk_nm_transposed(%A : !A_mk, %B : !B_nk, %C : !C_nm) -> !C_nm {
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$mk_kkmm]], #[[$kn_kknn]], #[[$mn_mmnn]]]
// CHECK-SAME: ["reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME: ins(%{{.*}} : tensor<128x8x32x8xf32>, tensor<8x8x32x16xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<8x128x8x16xf32>)
%0 = linalg.generic #mkn_trait ins(%A, %B : !A_mk, !B_nk) outs(%C : !C_nm) {
^bb0(%a: f32, %b: f32, %c: f32):
%d = arith.mulf %a, %b : f32
%e = arith.addf %c, %d : f32
linalg.yield %e : f32
} -> !C_nm
return %0 : !C_nm
}
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
transform.structured.pack_greedily %generic
gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
// -----
!A_bmkm2 = tensor<42x1023x255x33xf32>
!B_nkb = tensor<127x255x42xf32>
!C_nbm = tensor<127x42x1023xf32>
#mkn_accesses = [
affine_map<(k, m, n, b, m2) -> (b, m, k, m2)>,
affine_map<(k, m, n, b, m2) -> (n, k, b)>,
affine_map<(k, m, n, b, m2) -> (n, b, m)>
]
#mkn_trait = {
indexing_maps = #mkn_accesses,
iterator_types = ["reduction", "parallel", "parallel", "parallel", "parallel"]
}
// Normalized dims are: ( ?, ?, k, m, n)(kk, mm, nn)
// CHECK-DAG: #[[$bmkm2_kkmm:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d3, d2, d1, d5, d6)>
// CHECK-DAG: #[[$nkb_kknn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d2, d0, d5, d7)>
// CHECK-DAG: #[[$nbm_mmnn:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d0, d3, d6, d7)>
// CHECK-LABEL: @contraction_bmkm2_nkb_nbm(
func.func @contraction_bmkm2_nkb_nbm(%A : !A_bmkm2, %B : !B_nkb, %C : !C_nbm) -> !C_nbm {
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$bmkm2_kkmm]], #[[$nkb_kknn]], #[[$nbm_mmnn]]]
// CHECK-SAME: ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]}
// CHECK-SAME: ins(%{{.*}} : tensor<42x128x8x33x32x8xf32>, tensor<8x8x42x32x16xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<8x42x128x8x16xf32>)
%0 = linalg.generic #mkn_trait ins(%A, %B : !A_bmkm2, !B_nkb) outs(%C : !C_nbm) {
^bb0(%a: f32, %b: f32, %c: f32):
%d = arith.mulf %a, %b : f32
%e = arith.addf %c, %d : f32
linalg.yield %e : f32
} -> !C_nbm
return %0 : !C_nbm
}
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
transform.structured.pack_greedily %generic
gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}
// -----
// Conv lingo: h w kh kw c n f cc nn ff
// Normalized dims are: ( ?, ?, ?, ?, k, m, n)(kk, mm, nn)
// n c h + kh w + kw cc nn
// CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d5, d4, d0 + d2, d1 + d3, d7, d8)>
// f c kh kw cc ff
// CHECK-DAG: #[[$M2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d6, d4, d2, d3, d7, d9)>
// n f h w nn ff
// CHECK-DAG: #[[$M3:.*]] = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9) -> (d5, d6, d0, d1, d8, d9)>
// CHECK-LABEL: @conv_2d_nchw_fchw
func.func @conv_2d_nchw_fchw(%arg0: tensor<?x47x16x16xf32>, %arg2: tensor<?x16x14x14xf32>) -> tensor<?x16x14x14xf32> {
%c0 = arith.constant dense<0.1> : tensor<16x47x3x3xf32>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$M1]], #[[$M2]], #[[$M3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "reduction", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel"]
// CHECK-SAME: ins(%{{.*}} : tensor<?x2x16x16x32x8xf32>, tensor<1x2x3x3x32x16xf32>)
// CHECK-SAME: outs(%{{.*}} : tensor<?x1x14x14x8x16xf32>)
%0 = linalg.conv_2d_nchw_fchw
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %c0: tensor<?x47x16x16xf32>, tensor<16x47x3x3xf32>)
outs(%arg2: tensor<?x16x14x14xf32>) -> tensor<?x16x14x14xf32>
return %0 : tensor<?x16x14x14xf32>
}
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%conv = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %module_op
: (!pdl.operation) -> !transform.op<"linalg.conv_2d_nchw_fchw">
transform.structured.pack_greedily %conv
gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.conv_2d_nchw_fchw">) -> !transform.op<"linalg.generic">
}
// -----
// These should fail to pack for now as they don't contain a contraction.
// CHECK-LABEL: @reduce_and_map
func.func @reduce_and_map(%arg0: tensor<10x100xf32>,
%arg1: tensor<10x100xf32>, %output: tensor<10xf32>) -> tensor<10xf32> {
%map_init = tensor.empty() : tensor<10x100xf32>
// CHECK: linalg.map
%mapped = linalg.map { arith.addf }
ins(%arg0, %arg1 : tensor<10x100xf32>, tensor<10x100xf32>)
outs(%map_init : tensor<10x100xf32>)
// CHECK: linalg.reduce
%res = linalg.reduce { arith.addf }
ins(%mapped: tensor<10x100xf32>)
outs(%output: tensor<10xf32>)
dimensions = [1]
return %res : tensor<10xf32>
}
transform.sequence failures(propagate) {
^bb1(%module_op: !pdl.operation):
%generic = transform.structured.match ops{["linalg.generic"]} in %module_op : (!pdl.operation) -> !transform.op<"linalg.generic">
transform.structured.pack_greedily %generic
gemm_packed_sizes = [8, 16, 32] gemm_inner_dims_order = [1, 2, 0]
: (!transform.op<"linalg.generic">) -> !transform.op<"linalg.generic">
}