Generalize / improve DMA transfer overlap; nested and multiple DMA support; resolve multiple TODOs.

- replace the fake test pass (which worked on just the first loop in the
  MLFunction) with one that performs DMA pipelining on all suitable loops.
- nested DMAs work now (DMAs in an outer loop, more DMAs in nested inner loops)
- fix bugs / assumptions: correctly copy the memory space and element type of the
  source memref for double buffering.
- correctly identify matching start/finish statements, handle multiple DMAs per
  loop.
- introduce dominates/properlyDominates utilities for MLFunction statements.
- move checkDominancePreservationOnShift to LoopAnalysis.h; rename it
  isStmtwiseShiftValid
- refactor getContainingStmtPos -> findAncestorStmtInBlock and move it into
  StmtBlock; it has two users.
- other improvements / cleanup for related API/utilities
- add a size argument to dma_wait: for nested DMAs (and in general), this makes
  it easy to obtain the size to use when lowering the dma_wait, since we wouldn't
  want to have to identify the matching dma_start; more importantly, there may
  not always be a dma_start dominating the dma_wait in the future (see the
  example after this list).
- add debug information in the pass
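  For illustration, the dma_wait form before and after this change (taken
  from the updated pipeline-data-transfer test below):

    // before: only the tag memref and its indices
    dma_wait %tag[%zero] : memref<1 x f32>
    // after: the number of elements transferred by the matching dma_start
    // is passed explicitly
    dma_wait %tag[%zero], %num_elts : memref<1 x f32>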

PiperOrigin-RevId: 217734892
Commit 18e666702c (parent 3013dadb7c)
Author: Uday Bondhugula, 2018-10-18 11:14:26 -07:00
Committed by: jpienaar
14 changed files with 494 additions and 191 deletions

View File

@@ -22,6 +22,7 @@
#ifndef MLIR_ANALYSIS_LOOP_ANALYSIS_H
#define MLIR_ANALYSIS_LOOP_ANALYSIS_H
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
namespace mlir {
@@ -52,6 +53,13 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
// vectorizable. A function over the actions will give us a cost model.
bool isVectorizableLoop(const ForStmt &loop);
/// Checks whether SSA dominance would be violated if a for stmt's body statements
/// are shifted by the specified shifts. This method checks if a 'def' and all
/// its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
bool isStmtwiseShiftValid(const ForStmt &forStmt,
llvm::ArrayRef<uint64_t> shifts);
} // end namespace mlir
#endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H

View File

@@ -0,0 +1,40 @@
//===- Utils.h - General analysis utilities ---------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This header file defines prototypes for miscellaneous analysis utilities for
// non-loop IR structures. These are not passes by themselves but are used
// either by passes, optimization sequences, or in turn by other transformation
// utilities.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_UTILS_H
#define MLIR_ANALYSIS_UTILS_H
namespace mlir {
class Statement;
/// Returns true if statement 'a' dominates statement b.
bool dominates(const Statement &a, const Statement &b);
/// Returns true if statement 'a' properly dominates statement b.
bool properlyDominates(const Statement &a, const Statement &b);
} // end namespace mlir
#endif // MLIR_ANALYSIS_UTILS_H

View File

@@ -105,6 +105,26 @@ public:
void printBlock(raw_ostream &os) const;
void dumpBlock() const;
/// Returns the statement's position in this block or -1 if the statement is
/// not present.
int findStmtPosInBlock(const Statement &stmt) const {
unsigned j = 0;
for (const auto &s : statements) {
if (&s == &stmt)
return j;
j++;
}
return -1;
}
/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the
/// ancestor statement of 'stmt' that lies in this block. Returns nullptr if
/// the latter fails.
const Statement *findAncestorStmtInBlock(const Statement &stmt) const;
Statement *findAncestorStmtInBlock(Statement *stmt) {
return const_cast<Statement *>(findAncestorStmtInBlock(*stmt));
}
protected:
StmtBlock(StmtBlockKind kind) : kind(kind) {}
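A sketch of what findAncestorStmtInBlock computes (the op names below are
illustrative placeholders, not part of this change): for a statement nested in
an inner loop, the ancestor that lies in the outer 'for' statement's block is
the inner 'for' statement itself.

  for %i = 0 to 7 {                  // block on which the query is made
    for %j = 0 to 3 {
      "compute"(%i) : (index) -> ()  // findAncestorStmtInBlock on this statement
    }                                // returns the enclosing 'for %j' statement
  }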

View File

@@ -310,6 +310,25 @@ public:
getOperation()->operand_end()};
}
/// Returns true if this is a DMA from a slower memory space to a faster one.
bool isDestMemorySpaceFaster() const {
return (getSrcMemorySpace() < getDstMemorySpace());
}
/// Returns true if this is a DMA from a faster memory space to a slower one.
bool isSrcMemorySpaceFaster() const {
// Assumes that a lower number is for a slower memory space.
return (getDstMemorySpace() < getSrcMemorySpace());
}
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref, whichever is at the higher (faster) level of
/// the memory hierarchy. Asserts if neither memory space is faster than the other.
unsigned getFasterMemPos() const {
assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
}
static StringRef getOperationName() { return "dma_start"; }
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
@@ -321,8 +340,9 @@ protected:
// DmaWaitOp blocks until the completion of a DMA operation associated with the
// tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
// with the same restrictions as any load/store index in MLFunctions. For
// example:
// with the same restrictions as any load/store index in MLFunctions.
// %num_elements is the number of elements associated with the DMA operation.
// For example:
//
// dma_start %src[%i, %j], %dst[%k, %l], %tag[%index] :
// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
@@ -330,7 +350,7 @@ protected:
// memref<1 x i32>, (d0) -> (d0), 4>
// ...
// ...
// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4>
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
class DmaWaitOp
: public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
@@ -344,7 +364,18 @@ public:
// Returns the tag memref index for this DMA operation.
llvm::iterator_range<Operation::const_operand_iterator>
getTagIndices() const {
return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
return {getOperation()->operand_begin() + 1,
getOperation()->operand_begin() + 1 + getTagMemRefRank()};
}
// Returns the rank (number of indices) of the tag memref.
unsigned getTagMemRefRank() const {
return cast<MemRefType>(getTagMemRef()->getType())->getRank();
}
// Returns the number of elements transferred in the associated DMA operation.
const SSAValue *getNumElements() const {
return getOperand(1 + getTagMemRefRank());
}
protected:
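As a concrete sketch of the memory-space helpers above (memory spaces taken
from the test update below, assuming the convention that a lower memory space
number is slower): for the incoming DMA below, isDestMemorySpaceFaster() is
true and getFasterMemPos() returns the operand position of the destination
memref %Ah.

  // source %A in space 0 (slower), destination %Ah in space 1 (faster)
  dma_start %A[%i], %Ah[%i], %num_elts, %tag[%zero] :
    memref<256 x f32, (d0)->(d0), 0>, memref<32 x f32, (d0)->(d0), 1>, memref<1 x f32>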

View File

@@ -86,10 +86,6 @@ AffineMap getUnrolledLoopUpperBound(const ForStmt &forStmt,
UtilResult stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
bool unrollPrologueEpilogue = false);
/// Checks if SSA dominance would be violated if a for stmt's child statements
/// are shifted by the specified delays.
bool checkDominancePreservationOnShift(const ForStmt &forStmt,
ArrayRef<uint64_t> delays);
} // end namespace mlir

View File

@@ -45,7 +45,7 @@ class SSAValue;
/// Additional indices are added at the start.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
// extended to add additional indices at any position.
bool replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
llvm::ArrayRef<MLValue *> extraIndices,
AffineMap indexRemap = AffineMap::Invalid());

View File

@@ -24,8 +24,6 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/MLFunctionMatcher.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
@@ -128,12 +126,13 @@ static bool isAccessInvariant(MLValue *input, MemRefType *memRefType,
assert(indices.size() == memRefType->getRank());
assert(dim < indices.size());
auto layoutMap = memRefType->getAffineMaps();
assert(layoutMap.size() <= 1);
assert(memRefType->getAffineMaps().size() <= 1);
// TODO(ntv): remove dependency on Builder once we support non-identity
// layout map.
Builder b(memRefType->getContext());
assert(layoutMap.empty() ||
layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
(void)layoutMap;
SmallVector<OperationStmt *, 4> affineApplyOps;
getReachableAffineApplyOps({indices[dim]}, affineApplyOps);
@@ -197,3 +196,35 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) {
}
return true;
}
/// Checks whether SSA dominance would be violated if a for stmt's body
/// statements are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
ArrayRef<uint64_t> shifts) {
assert(shifts.size() == forStmt.getStatements().size());
unsigned s = 0;
for (const auto &stmt : forStmt) {
// A for or if stmt does not produce any def/results (that are used
// outside).
if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
const MLValue *result = opStmt->getResult(i);
for (const StmtOperand &use : result->getUses()) {
// If an ancestor statement doesn't lie in the block of forStmt, there
// is no shift to check.
// This is a naive way. If performance becomes an issue, a map can
// be used to store 'shifts' - to look up the shift for a statement in
// constant time.
if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
return false;
}
}
}
s++;
}
return true;
}
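A minimal sketch of the property isStmtwiseShiftValid checks (op names and
shift values are illustrative): a def and every use of its results must be
assigned the same shift, otherwise the shift vector is rejected.

  for %i = 0 to 7 {
    %v = "producer"() : () -> index   // def, assigned shift 0
    "use_a"(%v) : (index) -> ()       // shift 0: consistent with its def
    "use_b"(%v) : (index) -> ()       // shift 1: invalid - differs from the def's shift
  }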

View File

@@ -0,0 +1,62 @@
//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous analysis routines for non-loop IR
// structures.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Statements.h"
using namespace mlir;
/// Returns true if statement 'a' properly dominates statement b.
bool mlir::properlyDominates(const Statement &a, const Statement &b) {
if (&a == &b)
return false;
if (a.findFunction() != b.findFunction())
return false;
if (a.getBlock() == b.getBlock()) {
// Do a linear scan to determine whether b comes after a.
auto aIter = StmtBlock::const_iterator(a);
auto bIter = StmtBlock::const_iterator(b);
auto aBlockStart = a.getBlock()->begin();
while (bIter != aBlockStart) {
--bIter;
if (aIter == bIter)
return true;
}
return false;
}
// Traverse up b's hierarchy to check if b's block is contained in a's.
if (const auto *bAncestor = a.getBlock()->findAncestorStmtInBlock(b))
// a and bAncestor are in the same block; check if the former dominates it.
return dominates(a, *bAncestor);
// b's block is not contained in a's block.
return false;
}
/// Returns true if statement 'a' dominates statement 'b'.
bool mlir::dominates(const Statement &a, const Statement &b) {
return &a == &b || properlyDominates(a, b);
}
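For intuition (a sketch, not taken from this change): properlyDominates covers
both the same-block case, via a linear scan, and the nested case, by first
finding b's ancestor statement in a's block.

  for %i = 0 to 7 {
    %x = "a"() : () -> f32       // statement 'a'
    for %j = 0 to 3 {
      "b"(%x) : (f32) -> ()      // 'a' properly dominates 'b': b's ancestor in
    }                            // a's block is the inner for, which follows 'a'
  }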

View File

@@ -45,3 +45,19 @@ MLFunction *StmtBlock::findFunction() const {
}
return dyn_cast<MLFunction>(block);
}
/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor
/// statement of 'stmt' that lies in this block. Returns nullptr if the latter
/// fails.
const Statement *
StmtBlock::findAncestorStmtInBlock(const Statement &stmt) const {
// Traverse up the statement hierarchy starting from 'stmt' to find the
// ancestor statement that resides in this block.
const auto *currStmt = &stmt;
while (currStmt->getBlock() != this) {
currStmt = currStmt->getParentStmt();
if (!currStmt)
return nullptr;
}
return currStmt;
}

View File

@@ -392,7 +392,7 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
}
// Parse DmaStartOp.
// EX:
// Ex:
// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
// %tag[%index] :
// memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
@@ -458,33 +458,38 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------
// Parse DmaWaitOp.
// Eg:
// dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4>
//
void DmaWaitOp::print(OpAsmPrinter *p) const {
*p << getOperationName() << ' ';
// Print operands.
p->printOperand(getTagMemRef());
*p << '[';
p->printOperands(getTagIndices());
*p << ']';
*p << "], ";
p->printOperand(getNumElements());
*p << " : " << *getTagMemRef()->getType();
}
// Parse DmaWaitOp.
// Eg:
// dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
//
bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType tagMemrefInfo;
SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
Type *type;
auto *indexType = parser->getBuilder().getIndexType();
OpAsmParser::OperandType numElementsInfo;
// Parse tag memref and index.
// Parse tag memref, its indices, and dma size.
if (parser->parseOperand(tagMemrefInfo) ||
parser->parseOperandList(tagIndexInfos, -1,
OpAsmParser::Delimiter::Square) ||
parser->parseComma() || parser->parseOperand(numElementsInfo) ||
parser->parseColonType(type) ||
parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
parser->resolveOperands(tagIndexInfos, indexType, result->operands))
parser->resolveOperands(tagIndexInfos, indexType, result->operands) ||
parser->resolveOperand(numElementsInfo, indexType, result->operands))
return true;
if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())

View File

@@ -181,57 +181,6 @@ generateLoop(AffineMap lb, AffineMap ub,
return loopChunk;
}
// Returns delay of that child statement of 'forStmt' which either has 'operand'
// as one of its operands or has a descendant statement with operand 'operand'.
// This is a naive implementation. If performance becomes an issue, a map can
// be used to store 'delays' - to look up the delay for a statement in constant
// time.
static uint64_t getContainingStmtDelay(const StmtOperand &operand,
const ForStmt &forStmt,
ArrayRef<uint64_t> delays) {
// Traverse up the statement hierarchy starting from the owner of operand to
// find the ancestor statement that resides in the block of 'forStmt'.
const Statement *stmt = operand.getOwner();
assert(stmt != nullptr);
while (stmt->getParentStmt() != &forStmt) {
stmt = stmt->getParentStmt();
assert(stmt && "traversing parent's should reach forStmt block");
}
// Look up the delay of 'stmt'.
unsigned j = 0;
for (const auto &s : forStmt) {
if (&s == stmt)
break;
j++;
}
assert(j < forStmt.getStatements().size() && "child stmt should be found");
return delays[j];
}
/// Checks if SSA dominance would be violated if a for stmt's body statements
/// are shifted by the specified delays. This method checks if a 'def' and all
/// its uses have the same delay factor.
bool mlir::checkDominancePreservationOnShift(const ForStmt &forStmt,
ArrayRef<uint64_t> delays) {
assert(delays.size() == forStmt.getStatements().size());
unsigned s = 0;
for (const auto &stmt : forStmt) {
// A for or if stmt does not produce any def/results (that are used
// outside).
if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
const MLValue *result = opStmt->getResult(i);
for (const StmtOperand &use : result->getUses()) {
if (delays[s] != getContainingStmtDelay(use, forStmt, delays))
return false;
}
}
}
s++;
}
return true;
}
/// Skew the statements in the body of a 'for' statement with the specified
/// statement-wise delays. The delays are with respect to the original execution
/// order. A delay of zero for each statement will lead to no change.
@@ -260,7 +209,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
return UtilResult::Failure;
uint64_t tripCount = mayBeConstTripCount.getValue();
assert(checkDominancePreservationOnShift(*forStmt, delays) &&
assert(isStmtwiseShiftValid(*forStmt, delays) &&
"dominance preservation failed\n");
unsigned numChildStmts = forStmt->getStatements().size();

View File

@@ -22,21 +22,31 @@
#include "mlir/Transforms/Passes.h"
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/LoopAnalysis.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "pipeline-data-transfer"
using namespace mlir;
namespace {
struct PipelineDataTransfer : public MLFunctionPass {
explicit PipelineDataTransfer() {}
struct PipelineDataTransfer : public MLFunctionPass,
StmtWalker<PipelineDataTransfer> {
PassResult runOnMLFunction(MLFunction *f) override;
PassResult runOnForStmt(ForStmt *forStmt);
// Collect all 'for' statements.
void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
std::vector<ForStmt *> forStmts;
};
} // end anonymous namespace
@@ -47,20 +57,6 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
return new PipelineDataTransfer();
}
/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) {
unsigned srcDmaPos = 0;
unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1;
if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace())
return srcDmaPos;
return destDmaPos;
}
// Returns the position of the tag memref operand given a DMA statement.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
@@ -76,18 +72,20 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
/// Doubles the buffer of the supplied memref while replacing all uses of the
/// old memref. Returns false if such a replacement cannot be performed.
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
MLFuncBuilder bInner(forStmt, forStmt->begin());
bInner.setInsertionPoint(forStmt, forStmt->begin());
// Doubles the shape with a leading dimension extent of 2.
auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
// Add the leading dimension in the shape for the double buffer.
ArrayRef<int> shape = origMemRefType->getShape();
ArrayRef<int> shape = oldMemRefType->getShape();
SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
shapeSizes.insert(shapeSizes.begin(), 2);
auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type());
auto *newMemRefType =
bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
oldMemRefType->getMemorySpace());
return newMemRefType;
};
@@ -105,113 +103,187 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
cast<MLValue>(ivModTwoOp->getResult(0))))
cast<MLValue>(ivModTwoOp->getResult(0)))) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
return false;
}
return true;
}
// For testing purposes, this just runs on the first 'for' statement of an
// MLFunction at the top level.
// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
// the other TODOs listed inside are dealt with.
/// Returns false if this succeeds on at least one 'for' stmt.
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
if (f->empty())
return PassResult::Success;
ForStmt *forStmt = nullptr;
for (auto &stmt : *f) {
if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
break;
}
// Do a post order walk so that inner loop DMAs are processed first. This is
// necessary since 'for' statements nested within would otherwise become
// invalid (erased) when the outer loop is pipelined (the pipelined one gets
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forStmts.clear();
walkPostOrder(f);
bool ret = true;
for (auto *forStmt : forStmts) {
ret = ret & runOnForStmt(forStmt);
}
if (!forStmt)
return PassResult::Success;
return ret ? failure() : success();
}
unsigned numStmts = forStmt->getStatements().size();
// Check if tags of the dma start op and dma wait op match.
static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
OpPointer<DmaWaitOp> waitOp) {
if (startOp->getTagMemRef() != waitOp->getTagMemRef())
return false;
auto startIndices = startOp->getTagIndices();
auto waitIndices = waitOp->getTagIndices();
// Both of these have the same number of indices since they correspond to the
// same tag memref.
for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
e = startIndices.end();
it != e; ++it, ++wIt) {
// Keep it simple for now, just checking if indices match.
// TODO(mlir-team): this would in general need to check if there is no
// intervening write writing to the same tag location, i.e., memory last
// write/data flow analysis. This is however sufficient/powerful enough for
// now since the DMA generation pass or the input for it will always have
// start/wait with matching tags (same SSA operand indices).
if (*it != *wIt)
return false;
}
return true;
}
if (numStmts == 0)
return PassResult::Success;
SmallVector<OperationStmt *, 4> dmaStartStmts;
SmallVector<OperationStmt *, 4> dmaFinishStmts;
// Identify matching DMA start/finish statements to overlap computation with.
static void findMatchingStartFinishStmts(
ForStmt *forStmt,
SmallVectorImpl<std::pair<OperationStmt *, OperationStmt *>>
&startWaitPairs) {
SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
for (auto &stmt : *forStmt) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt)
continue;
if (opStmt->is<DmaStartOp>()) {
dmaStartStmts.push_back(opStmt);
} else if (opStmt->is<DmaWaitOp>()) {
// Collect DMA finish statements.
if (opStmt->is<DmaWaitOp>()) {
dmaFinishStmts.push_back(opStmt);
continue;
}
OpPointer<DmaStartOp> dmaStartOp;
if (!(dmaStartOp = opStmt->getAs<DmaStartOp>()))
continue;
// Only DMAs incoming into higher memory spaces.
// TODO(bondhugula): outgoing DMAs.
if (!dmaStartOp->isDestMemorySpaceFaster())
continue;
// We only double buffer if the buffer is not live out of loop.
const MLValue *memref =
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
if (!dominates(*forStmt, *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
break;
}
}
if (!escapingUses)
dmaStartStmts.push_back(opStmt);
}
// TODO(bondhugula,andydavis): match tag memref's (requires memory-based
// subscript check utilities). Assume for now that start/finish are matched in
// the order they appear.
if (dmaStartStmts.size() != dmaFinishStmts.size())
// For each start statement, we look for a matching finish statement.
for (auto *dmaStartStmt : dmaStartStmts) {
for (auto *dmaFinishStmt : dmaFinishStmts) {
if (checkTagMatch(dmaStartStmt->getAs<DmaStartOp>(),
dmaFinishStmt->getAs<DmaWaitOp>())) {
startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt});
break;
}
}
}
}
/// Overlap DMA transfers with computation in this loop. If successful,
/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are
/// inserted right before where it was.
PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
auto mayBeConstTripCount = getConstantTripCount(*forStmt);
if (!mayBeConstTripCount.hasValue()) {
LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
return PassResult::Failure;
}
SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs;
findMatchingStartFinishStmts(forStmt, startWaitPairs);
if (startWaitPairs.empty()) {
LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
return failure();
}
// Double the buffers for the higher memory space memref's.
// TODO(bondhugula): assuming we don't have multiple DMA starts for the same
// memref.
// Identify memref's to replace by scanning through all DMA start statements.
// A DMA start statement has two memref's - the one from the higher level of
// memory hierarchy is the one to double buffer.
// TODO(bondhugula): check whether double-buffering is even necessary.
// TODO(bondhugula): make this work with different layouts: assuming here that
// the dimension we are adding here for the double buffering is the outermost
// dimension.
// Identify memref's to replace by scanning through all DMA start statements.
// A DMA start statement has two memref's - the one from the higher level of
// memory hierarchy is the one to double buffer.
for (auto *dmaStartStmt : dmaStartStmts) {
MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>())));
for (auto &pair : startWaitPairs) {
auto *dmaStartStmt = pair.first;
const MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
dmaStartStmt->getAs<DmaStartOp>()->getFasterMemPos()));
if (!doubleBuffer(oldMemRef, forStmt)) {
return PassResult::Failure;
// Normally, double buffering should not fail because we already checked
// that there are no uses outside.
LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
LLVM_DEBUG(dmaStartStmt->dump());
return failure();
}
}
// Double the buffers for tag memref's.
for (auto *dmaFinishStmt : dmaFinishStmts) {
MLValue *oldTagMemRef = cast<MLValue>(
// Double the buffers for tag memrefs.
for (auto &pair : startWaitPairs) {
const auto *dmaFinishStmt = pair.second;
const MLValue *oldTagMemRef = cast<MLValue>(
dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
if (!doubleBuffer(oldTagMemRef, forStmt)) {
return PassResult::Failure;
LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
return failure();
}
}
// Collect all compute ops.
std::vector<const Statement *> computeOps;
computeOps.reserve(forStmt->getStatements().size());
// Double buffering would have invalidated all the old DMA start/wait stmts.
startWaitPairs.clear();
findMatchingStartFinishStmts(forStmt, startWaitPairs);
// Store delay for statement for later lookup for AffineApplyOp's.
DenseMap<const Statement *, unsigned> opDelayMap;
for (auto &stmt : *forStmt) {
auto *opStmt = dyn_cast<OperationStmt>(&stmt);
if (!opStmt) {
// All for and if stmt's are treated as pure compute operations.
opDelayMap[&stmt] = 1;
} else if (opStmt->is<DmaStartOp>()) {
// DMA starts are not shifted.
opDelayMap[opStmt] = 0;
// Set shifts for DMA start stmt's affine operand computation slices to 0.
if (auto *slice = mlir::createAffineComputationSlice(opStmt)) {
opDelayMap[slice] = 0;
} else {
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
SmallVector<OperationStmt *, 4> affineApplyStmts;
SmallVector<MLValue *, 4> operands(opStmt->getOperands());
getReachableAffineApplyOps(operands, affineApplyStmts);
for (auto *op : affineApplyStmts) {
opDelayMap[op] = 0;
}
}
} else if (opStmt->is<DmaWaitOp>()) {
// DMA finish op shifted by one.
opDelayMap[opStmt] = 1;
DenseMap<const Statement *, unsigned> stmtDelayMap;
for (auto &pair : startWaitPairs) {
auto *dmaStartStmt = pair.first;
assert(dmaStartStmt->is<DmaStartOp>());
stmtDelayMap[dmaStartStmt] = 0;
// Set shifts for DMA start stmt's affine operand computation slices to 0.
if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
stmtDelayMap[slice] = 0;
} else {
// Everything else is a compute op; so shifted by one (op's supplying
// 'affine' operands to DMA start's have already been set right shifts.
opDelayMap[opStmt] = 1;
computeOps.push_back(&stmt);
// If a slice wasn't created, the reachable affine_apply op's from its
// operands are the ones that go with it.
SmallVector<OperationStmt *, 4> affineApplyStmts;
SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands());
getReachableAffineApplyOps(operands, affineApplyStmts);
for (const auto *stmt : affineApplyStmts) {
stmtDelayMap[stmt] = 0;
}
}
}
// Everything else (including compute ops and dma finish) are shifted by one.
for (const auto &stmt : *forStmt) {
if (stmtDelayMap.find(&stmt) == stmtDelayMap.end()) {
stmtDelayMap[&stmt] = 1;
}
}
@@ -219,18 +291,20 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
std::vector<uint64_t> delays(forStmt->getStatements().size());
unsigned s = 0;
for (const auto &stmt : *forStmt) {
assert(opDelayMap.find(&stmt) != opDelayMap.end());
delays[s++] = opDelayMap[&stmt];
assert(stmtDelayMap.find(&stmt) != stmtDelayMap.end());
delays[s++] = stmtDelayMap[&stmt];
}
if (!checkDominancePreservationOnShift(*forStmt, delays)) {
if (!isStmtwiseShiftValid(*forStmt, delays)) {
// Violates SSA dominance.
LLVM_DEBUG(llvm::dbgs() << "Dominance check failed\n";);
return PassResult::Failure;
}
if (stmtBodySkew(forStmt, delays)) {
LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed\n";);
return PassResult::Failure;
}
return PassResult::Success;
return success();
}
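To make the doubleBuffer rewrite above concrete (a sketch mirroring the test
update below): a fast-space buffer such as memref<32 x f32, 1> becomes
memref<2x32xf32, 1>, and each access gets an extra leading index of (iv mod 2)
so that successive iterations alternate between the two halves.

  // before double buffering
  %Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
  // after: a leading dimension of 2, indexed by the loop IV mod 2
  // (#map2 = (d0) -> (d0 mod 2), as in the test's CHECK lines)
  %buf = alloc() : memref<2x32xf32, 1>
  %p = affine_apply #map2(%i)
  %v = load %buf[%p, %i] : memref<2x32xf32, 1>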

View File

@@ -48,7 +48,8 @@ static bool isMemRefDereferencingOp(const Operation &op) {
/// at the start for now.
// TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
// extended to add additional indices at any position.
bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
MLValue *newMemRef,
ArrayRef<MLValue *> extraIndices,
AffineMap indexRemap) {
unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
@@ -219,11 +220,11 @@ mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc,
/// This allows applying different transformations on send and compute (for eg.
/// different shifts/delays).
///
/// Returns nullptr if none of the operands were the result of an affine_apply
/// and thus there was no affine computation slice to create. Returns the newly
/// affine_apply operation statement otherwise.
///
///
/// Returns nullptr either if none of opStmt's operands were the result of an
/// affine_apply and thus there was no affine computation slice to create, or if
/// all the affine_apply op's supplying operands to this opStmt do not have any
/// uses besides this opStmt. Returns the new affine_apply operation statement
/// otherwise.
OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
// Collect all operands that are results of affine apply ops.
SmallVector<MLValue *, 4> subOperands;
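An illustrative sketch of the slice this utility creates (op names and the map
are placeholders): when a single affine_apply feeds both a DMA statement and a
compute statement, the DMA gets its own clone, so the two consumers can later
be shifted by different amounts.

  // before: one affine_apply feeds both statements
  %idx = affine_apply #map0(%i)
  "dma_like"(%idx) : (index) -> ()
  "compute"(%idx) : (index) -> ()
  // after createAffineComputationSlice on the "dma_like" statement
  %idx_dma = affine_apply #map0(%i)     // private slice for the DMA (shift 0)
  "dma_like"(%idx_dma) : (index) -> ()
  "compute"(%idx) : (index) -> ()       // compute keeps the original (shift 1)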

View File

@@ -1,30 +1,28 @@
// RUN: mlir-opt %s -pipeline-data-transfer | FileCheck %s
// CHECK: #map0 = (d0) -> (d0 mod 2, d0 mod 2)
// CHECK-NEXT: #map1 = (d0) -> (d0 - 1)
// CHECK-NEXT: #map2 = (d0) -> (d0 mod 2)
// CHECK-NEXT: mlfunc @loop_nest_dma() {
// CHECK-NEXT: %c8 = constant 8 : index
// CHECK-LABEL: mlfunc @loop_nest_dma() {
mlfunc @loop_nest_dma() {
// CHECK: %c8 = constant 8 : index
// CHECK-NEXT: %c0 = constant 0 : index
// CHECK-NEXT: %0 = alloc() : memref<2x1xf32>
// CHECK-NEXT: %1 = alloc() : memref<2x32xf32>
// CHECK-NEXT: %1 = alloc() : memref<2x32xf32, 1>
// CHECK-NEXT: %2 = alloc() : memref<256xf32, (d0) -> (d0)>
// CHECK-NEXT: %3 = alloc() : memref<32xf32, (d0) -> (d0), 1>
// CHECK-NEXT: %4 = alloc() : memref<1xf32>
// CHECK-NEXT: %c0_0 = constant 0 : index
// CHECK-NEXT: %c128 = constant 128 : index
// CHECK-NEXT: %5 = affine_apply #map0(%c0)
// CHECK-NEXT: dma_start %2[%c0], %1[%5#0, %c0], %c128, %0[%5#1, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32>, memref<2x1xf32>
// CHECK-NEXT: dma_start %2[%c0], %1[%5#0, %c0], %c128, %0[%5#1, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32, 1>, memref<2x1xf32>
// CHECK-NEXT: for %i0 = 1 to 7 {
// CHECK-NEXT: %6 = affine_apply #map0(%i0)
// CHECK-NEXT: dma_start %2[%i0], %1[%6#0, %i0], %c128, %0[%6#1, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32>, memref<2x1xf32>
// CHECK-NEXT: dma_start %2[%i0], %1[%6#0, %i0], %c128, %0[%6#1, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32, 1>, memref<2x1xf32>
// CHECK-NEXT: %7 = affine_apply #map1(%i0)
// CHECK-NEXT: %8 = affine_apply #map2(%7)
// CHECK-NEXT: %9 = affine_apply #map2(%7)
// CHECK-NEXT: dma_wait %0[%8, %c0_0] : memref<2x1xf32>
// CHECK-NEXT: %10 = load %1[%9, %7] : memref<2x32xf32>
// CHECK-NEXT: dma_wait %0[%8, %c0_0], %c128 : memref<2x1xf32>
// CHECK-NEXT: %10 = load %1[%9, %7] : memref<2x32xf32, 1>
// CHECK-NEXT: %11 = "compute"(%10) : (f32) -> f32
// CHECK-NEXT: store %11, %1[%9, %7] : memref<2x32xf32>
// CHECK-NEXT: store %11, %1[%9, %7] : memref<2x32xf32, 1>
// CHECK-NEXT: for %i1 = 0 to 127 {
// CHECK-NEXT: "do_more_compute"(%7, %i1) : (index, index) -> ()
// CHECK-NEXT: }
@@ -32,15 +30,14 @@
// CHECK-NEXT: %12 = affine_apply #map1(%c8)
// CHECK-NEXT: %13 = affine_apply #map2(%12)
// CHECK-NEXT: %14 = affine_apply #map2(%12)
// CHECK-NEXT: dma_wait %0[%13, %c0_0] : memref<2x1xf32>
// CHECK-NEXT: %15 = load %1[%14, %12] : memref<2x32xf32>
// CHECK-NEXT: dma_wait %0[%13, %c0_0], %c128 : memref<2x1xf32>
// CHECK-NEXT: %15 = load %1[%14, %12] : memref<2x32xf32, 1>
// CHECK-NEXT: %16 = "compute"(%15) : (f32) -> f32
// CHECK-NEXT: store %16, %1[%14, %12] : memref<2x32xf32>
// CHECK-NEXT: store %16, %1[%14, %12] : memref<2x32xf32, 1>
// CHECK-NEXT: for %i2 = 0 to 127 {
// CHECK-NEXT: "do_more_compute"(%12, %i2) : (index, index) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: return
mlfunc @loop_nest_dma() {
%A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
%Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
@@ -48,11 +45,11 @@ mlfunc @loop_nest_dma() {
%tag = alloc() : memref<1 x f32>
%zero = constant 0 : index
%size = constant 128 : index
%num_elts = constant 128 : index
for %i = 0 to 7 {
dma_start %A[%i], %Ah[%i], %size, %tag[%zero] : memref<256 x f32, (d0)->(d0), 0>, memref<32 x f32, (d0)->(d0), 1>, memref<1 x f32>
dma_wait %tag[%zero] : memref<1 x f32>
dma_start %A[%i], %Ah[%i], %num_elts, %tag[%zero] : memref<256 x f32, (d0)->(d0), 0>, memref<32 x f32, (d0)->(d0), 1>, memref<1 x f32>
dma_wait %tag[%zero], %num_elts : memref<1 x f32>
%v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
%r = "compute"(%v) : (f32) -> (f32)
store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
@@ -62,3 +59,76 @@ mlfunc @loop_nest_dma() {
}
return
}
#map0 = (d0, d1) -> (d0, d1)
#map1 = (d0, d1) -> ((d0 * 2048 + d1 * 256) floordiv 32, 0)
#map2 = (d0) -> ((d0 * 2048) floordiv 32, 0)
// CHECK: mlfunc @loop_dma_nested(%arg0 : memref<512x32xvector<8xf32>
mlfunc @loop_dma_nested(%arg0 : memref<512x32xvector<8xf32>, #map0>, %arg1 : memref<512x32xvector<8xf32>, #map0>, %arg2 : memref<512x32xvector<8xf32>, #map0>) {
%num_elts = constant 256 : index
%c0 = constant 0 : index
%0 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
%1 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
%2 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
%3 = alloc() : memref<2xi32>
%4 = alloc() : memref<2xi32>
%5 = alloc() : memref<2xi32>
// Prologue for DMA overlap on arg2.
// CHECK: dma_start %arg2[
// CHECK-NEXT: for %i0 = 1 to 7 {
for %i0 = 0 to 7 {
%6 = affine_apply #map2(%i0)
dma_start %arg2[%6#0, %6#1], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
dma_wait %5[%c0], %num_elts : memref<2xi32>
// Steady state for DMA overlap on arg2
// CHECK: dma_start %arg2[
// CHECK: dma_wait %0[
// Prologue for DMA overlap on arg0, arg1 nested within i0
// CHECK: dma_start %arg0[
// CHECK: dma_start %arg1[
// CHECK-NEXT for %i1 = 1 to 7 {
for %i1 = 0 to 7 {
%7 = affine_apply #map1(%i0, %i1)
%8 = affine_apply #map2(%i1)
dma_start %arg0[%7#0, %7#1], %0[%c0, %c0], %num_elts, %3[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
dma_start %arg1[%8#0, %8#1], %1[%c0, %c0], %num_elts, %4[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
dma_wait %3[%c0], %num_elts : memref<2xi32>
dma_wait %4[%c0], %num_elts : memref<2xi32>
// Steady state for DMA overlap on arg0, arg1
// CHECK: dma_start %arg0[
// CHECK: dma_start %arg1[
// CHECK: dma_wait %3[
// CHECK: dma_wait %2[
// CHECK-NEXT: for %i2 = 0 to 3 {
for %i2 = 0 to 3 {
"foo"() : () -> ()
}
}
// epilogue for arg0, arg1
// CHECK: dma_wait %3[
// CHECK: dma_wait %2[
// epilogue for DMA overlap on %arg2
// CHECK: dma_wait %0[%37, %c0_2], %c256 : memref<2x2xi32>
// Within the epilogue for arg2's DMA, we have the DMAs on %arg0, %arg1 nested.
// CHECK: dma_start %arg0[
// CHECK: dma_start %arg1[
// CHECK: for %i4 = 1 to 7 {
// CHECK: dma_start %arg0[
// CHECK: dma_start %arg1[
// CHECK: dma_wait %3[
// CHECK: dma_wait %2[
// CHECK: for %i5 = 0 to 3 {
// CHECK: "foo"() : () -> ()
// CHECK: dma_wait %3[
// CHECK: dma_wait %2[
// CHECK: for %i6 = 0 to 3 {
// The DMAs below are outgoing DMAs on arg2, not yet overlapped.
// CHECK: dma_start %1{{.*}}, %arg2[
// CHECK-NEXT: dma_wait %0[
dma_start %2[%c0, %c0], %arg2[%6#0, %6#1], %num_elts, %5[%c0] : memref<64x4xvector<8xf32>, #map0, 2>, memref<512x32xvector<8xf32>, #map0>, memref<2xi32>
dma_wait %5[%c0], %num_elts : memref<2xi32>
} // CHECK }
return
}