Generalize / improve DMA transfer overlap; nested and multiple DMA support; resolve multiple TODOs.

- Replace the fake test pass (which worked on just the first loop in the MLFunction) with one that performs DMA pipelining on all suitable loops.
- Nested DMAs now work (DMAs in an outer loop, more DMAs in nested inner loops).
- Fix bugs / assumptions: correctly copy the memory space and elemental type of the source memref for double buffering.
- Correctly identify matching start/finish statements; handle multiple DMAs per loop.
- Introduce dominates/properlyDominates utilities for MLFunction statements.
- Move checkDominancePreservationOnShift to LoopAnalysis.h; rename it isStmtwiseShiftValid (the name the code below uses).
- Refactor getContainingStmtPos -> findAncestorStmtInBlock; move it into Analysis/Utils.h; it has two users.
- Other improvements / cleanup for related APIs and utilities.
- Add a size argument to dma_wait: for nested DMAs, and in general, this makes it easy to obtain the size to use when lowering the dma_wait, since we wouldn't want to have to identify the matching dma_start; more importantly, there may not always be a dma_start dominating the dma_wait in the future.
- Add debug information in the pass.

PiperOrigin-RevId: 217734892
Committed by jpienaar · parent 3013dadb7c · commit 18e666702c
@@ -22,6 +22,7 @@
 #ifndef MLIR_ANALYSIS_LOOP_ANALYSIS_H
 #define MLIR_ANALYSIS_LOOP_ANALYSIS_H
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
 
 namespace mlir {
@@ -52,6 +53,13 @@ uint64_t getLargestDivisorOfTripCount(const ForStmt &forStmt);
 // vectorizable. A function over the actions will give us a cost model.
 bool isVectorizableLoop(const ForStmt &loop);
 
+/// Checks whether SSA dominance would be violated if a for stmt's body
+/// statements are shifted by the specified shifts. This method checks if a
+/// 'def' and all its uses have the same shift factor.
+// TODO(mlir-team): extend this to check for memory-based dependence
+// violation when we have the support.
+bool isStmtwiseShiftValid(const ForStmt &forStmt,
+                          llvm::ArrayRef<uint64_t> shifts);
 } // end namespace mlir
 
 #endif // MLIR_ANALYSIS_LOOP_ANALYSIS_H
mlir/include/mlir/Analysis/Utils.h — new file (+40)
@@ -0,0 +1,40 @@
+//===- Utils.h - General analysis utilities ---------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes for various transformation utilities for
+// memref's and non-loop IR structures. These are not passes by themselves but
+// are used either by passes, optimization sequences, or in turn by other
+// transformation utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_UTILS_H
+#define MLIR_ANALYSIS_UTILS_H
+
+namespace mlir {
+
+class Statement;
+
+/// Returns true if statement 'a' dominates statement b.
+bool dominates(const Statement &a, const Statement &b);
+
+/// Returns true if statement 'a' properly dominates statement b.
+bool properlyDominates(const Statement &a, const Statement &b);
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_UTILS_H
@@ -105,6 +105,26 @@ public:
   void printBlock(raw_ostream &os) const;
   void dumpBlock() const;
 
+  /// Returns the statement's position in this block or -1 if the statement is
+  /// not present.
+  int findStmtPosInBlock(const Statement &stmt) const {
+    unsigned j = 0;
+    for (const auto &s : statements) {
+      if (&s == &stmt)
+        return j;
+      j++;
+    }
+    return -1;
+  }
+
+  /// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the
+  /// ancestor statement of 'stmt' that lies in this block. Returns nullptr if
+  /// the latter fails.
+  const Statement *findAncestorStmtInBlock(const Statement &stmt) const;
+  Statement *findAncestorStmtInBlock(Statement *stmt) {
+    return const_cast<Statement *>(findAncestorStmtInBlock(*stmt));
+  }
+
 protected:
   StmtBlock(StmtBlockKind kind) : kind(kind) {}
 
@@ -310,6 +310,25 @@
                                  getOperation()->operand_end()};
   }
 
+  /// Returns true if this is a DMA from a slower memory space to a faster one.
+  bool isDestMemorySpaceFaster() const {
+    return (getSrcMemorySpace() < getDstMemorySpace());
+  }
+
+  /// Returns true if this is a DMA from a faster memory space to a slower one.
+  bool isSrcMemorySpaceFaster() const {
+    // Assumes that a lower number is for a slower memory space.
+    return (getDstMemorySpace() < getSrcMemorySpace());
+  }
+
+  /// Given a DMA start operation, returns the operand position of either the
+  /// source or destination memref depending on the one that is at the higher
+  /// level of the memory hierarchy. Asserts if neither is.
+  unsigned getFasterMemPos() const {
+    assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
+    return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1;
+  }
+
   static StringRef getOperationName() { return "dma_start"; }
   static bool parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p) const;
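To make the operand arithmetic in getFasterMemPos() concrete, here is a self-contained toy sketch (plain C++, not the MLIR API; the function name is made up). It assumes the dma_start operand layout used above: the source memref at operand 0, its rank indices next, then the destination memref, with a higher memory-space number meaning faster memory.

#include <cassert>

// Toy mirror of getFasterMemPos(): returns the operand position of the memref
// in the faster (higher-numbered) memory space.
unsigned fasterMemPos(unsigned srcSpace, unsigned dstSpace, unsigned srcRank) {
  assert(srcSpace != dstSpace && "one side must be the faster space");
  // dma_start layout: %src at 0, its srcRank indices next, then %dst.
  return srcSpace > dstSpace ? 0 : srcRank + 1;
}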
@@ -321,8 +340,9 @@ protected:
 
 // DmaWaitOp blocks until the completion of a DMA operation associated with the
 // tag element '%tag[%index]'. %tag is a memref, and %index has to be an index
-// with the same restrictions as any load/store index in MLFunctions. For
-// example:
+// with the same restrictions as any load/store index in MLFunctions.
+// %num_elements is the number of elements associated with the DMA operation.
+// For example:
 //
 //   dma_start %src[%i, %j], %dst[%k, %l], %tag[%index] :
 //     memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
@@ -330,7 +350,7 @@ protected:
 //     memref<1 x i32>, (d0) -> (d0), 4>
 //   ...
 //   ...
-//   dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4>
+//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
 //
 class DmaWaitOp
     : public Op<DmaWaitOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
@@ -344,7 +364,18 @@ public:
   // Returns the tag memref index for this DMA operation.
   llvm::iterator_range<Operation::const_operand_iterator>
   getTagIndices() const {
-    return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
+    return {getOperation()->operand_begin() + 1,
+            getOperation()->operand_begin() + 1 + getTagMemRefRank()};
   }
 
+  // Returns the rank (number of indices) of the tag memref.
+  unsigned getTagMemRefRank() const {
+    return cast<MemRefType>(getTagMemRef()->getType())->getRank();
+  }
+
+  // Returns the number of elements transferred in the associated DMA operation.
+  const SSAValue *getNumElements() const {
+    return getOperand(1 + getTagMemRefRank());
+  }
+
 protected:
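For reference, the accessors above imply a fixed dma_wait operand layout; a hypothetical one-liner (plain C++, made-up name) records it: the tag memref at position 0, its rank indices next, and the new %num_elements operand last.

// Toy mirror of DmaWaitOp::getNumElements(): %tag at 0, indices at 1..rank,
// so the size operand lives at 1 + rank.
unsigned numElementsPos(unsigned tagMemRefRank) { return 1 + tagMemRefRank; }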
@@ -86,10 +86,6 @@ AffineMap getUnrolledLoopUpperBound(const ForStmt &forStmt,
 UtilResult stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
                         bool unrollPrologueEpilogue = false);
 
-/// Checks if SSA dominance would be violated if a for stmt's child statements
-/// are shifted by the specified delays.
-bool checkDominancePreservationOnShift(const ForStmt &forStmt,
-                                       ArrayRef<uint64_t> delays);
-
 } // end namespace mlir
@@ -45,7 +45,7 @@ class SSAValue;
 /// Additional indices are added at the start.
 // TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
 // extended to add additional indices at any position.
-bool replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
+bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
                               llvm::ArrayRef<MLValue *> extraIndices,
                               AffineMap indexRemap = AffineMap::Invalid());
 
@@ -24,8 +24,6 @@
 #include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/Analysis/AffineStructures.h"
-#include "mlir/Analysis/MLFunctionMatcher.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Statements.h"
@@ -128,12 +126,13 @@ static bool isAccessInvariant(MLValue *input, MemRefType *memRefType,
   assert(indices.size() == memRefType->getRank());
   assert(dim < indices.size());
-  assert(memRefType->getAffineMaps().size() <= 1);
+  auto layoutMap = memRefType->getAffineMaps();
+  assert(layoutMap.size() <= 1);
   // TODO(ntv): remove dependency on Builder once we support non-identity
   // layout map.
   Builder b(memRefType->getContext());
+  assert(layoutMap.empty() ||
+         layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
+  (void)layoutMap;
 
   SmallVector<OperationStmt *, 4> affineApplyOps;
   getReachableAffineApplyOps({indices[dim]}, affineApplyOps);
@@ -197,3 +196,35 @@ bool mlir::isVectorizableLoop(const ForStmt &loop) {
   }
   return true;
 }
+
+/// Checks whether SSA dominance would be violated if a for stmt's body
+/// statements are shifted by the specified shifts. This method checks if a
+/// 'def' and all its uses have the same shift factor.
+// TODO(mlir-team): extend this to check for memory-based dependence
+// violation when we have the support.
+bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
+                                ArrayRef<uint64_t> shifts) {
+  assert(shifts.size() == forStmt.getStatements().size());
+  unsigned s = 0;
+  for (const auto &stmt : forStmt) {
+    // A for or if stmt does not produce any def/results (that are used
+    // outside).
+    if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+      for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
+        const MLValue *result = opStmt->getResult(i);
+        for (const StmtOperand &use : result->getUses()) {
+          // If an ancestor statement doesn't lie in the block of forStmt, there
+          // is no shift to check.
+          // This is a naive way. If performance becomes an issue, a map can
+          // be used to store 'shifts' - to look up the shift for a statement in
+          // constant time.
+          if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
+            if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
+              return false;
+        }
+      }
+    }
+    s++;
+  }
+  return true;
+}
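To make the invariant concrete, here is a self-contained toy model (plain C++, not the MLIR API; ToyStmt and all names are made up, and it assumes every used value id has a recorded def with id < body.size()) of the rule the function enforces: a def and every use of it must be assigned the same shift.

#include <cstdint>
#include <vector>

struct ToyStmt {
  int def = -1;          // id of the value this statement defines, if any
  std::vector<int> uses; // ids of the values this statement reads
};

// Returns true if no value is defined under one shift and used under another.
bool toyShiftValid(const std::vector<ToyStmt> &body,
                   const std::vector<uint64_t> &shifts) {
  std::vector<uint64_t> defShift(body.size(), 0);
  for (size_t i = 0; i < body.size(); ++i)
    if (body[i].def >= 0)
      defShift[body[i].def] = shifts[i];
  for (size_t i = 0; i < body.size(); ++i)
    for (int v : body[i].uses)
      if (defShift[v] != shifts[i]) // use shifted differently from its def
        return false;
  return true;
}

This is also why the pipelining pass below gives the affine_apply slice feeding a dma_start the same shift (0) as the dma_start itself.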
mlir/lib/Analysis/Utils.cpp — new file (+62)
@@ -0,0 +1,62 @@
+//===- Utils.cpp ---- Misc utilities for analysis -------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous analysis routines for non-loop IR
+// structures.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Utils.h"
+
+#include "mlir/IR/Statements.h"
+
+using namespace mlir;
+
+/// Returns true if statement 'a' properly dominates statement b.
+bool mlir::properlyDominates(const Statement &a, const Statement &b) {
+  if (&a == &b)
+    return false;
+
+  if (a.findFunction() != b.findFunction())
+    return false;
+
+  if (a.getBlock() == b.getBlock()) {
+    // Do a linear scan to determine whether b comes after a.
+    auto aIter = StmtBlock::const_iterator(a);
+    auto bIter = StmtBlock::const_iterator(b);
+    auto aBlockStart = a.getBlock()->begin();
+    while (bIter != aBlockStart) {
+      --bIter;
+      if (aIter == bIter)
+        return true;
+    }
+    return false;
+  }
+
+  // Traverse up b's hierarchy to check if b's block is contained in a's.
+  if (const auto *bAncestor = a.getBlock()->findAncestorStmtInBlock(b))
+    // a and bAncestor are in the same block; check if the former dominates it.
+    return dominates(a, *bAncestor);
+
+  // b's block is not contained in a's.
+  return false;
+}
+
+/// Returns true if statement 'a' dominates statement b.
+bool mlir::dominates(const Statement &a, const Statement &b) {
+  return &a == &b || properlyDominates(a, b);
+}
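A self-contained toy model (plain C++; Node and all names are made up, not the MLIR Statement API) of the same two-step test: hoist 'b' to an ancestor in a's block, then compare positions within that block.

struct Node {
  Node *parent = nullptr; // enclosing statement, null at function level
  int block = 0;          // id of the block this node lives in
  int pos = 0;            // position within that block
};

// Walk 'n' up the nesting hierarchy until it lands in 'block' (or runs out).
const Node *ancestorInBlock(const Node *n, int block) {
  while (n && n->block != block)
    n = n->parent;
  return n;
}

bool toyProperlyDominates(const Node &a, const Node &b) {
  if (&a == &b)
    return false;
  const Node *bAnc = ancestorInBlock(&b, a.block);
  if (!bAnc)
    return false; // b is not nested under a's block
  // Either b is nested somewhere inside a itself, or a comes strictly
  // earlier than b's ancestor in their common block.
  return bAnc == &a || a.pos < bAnc->pos;
}

Later in this commit, the pipelining pass calls dominates(*forStmt, *use.getOwner()) in exactly this shape to flag buffer uses the loop statement does not dominate (treated as live out).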
@@ -45,3 +45,19 @@ MLFunction *StmtBlock::findFunction() const {
   }
   return dyn_cast<MLFunction>(block);
 }
+
+/// Returns 'stmt' if 'stmt' lies in this block, or otherwise finds the ancestor
+/// statement of 'stmt' that lies in this block. Returns nullptr if the latter
+/// fails.
+const Statement *
+StmtBlock::findAncestorStmtInBlock(const Statement &stmt) const {
+  // Traverse up the statement hierarchy starting from 'stmt' to find the
+  // ancestor statement that resides in this block.
+  const auto *currStmt = &stmt;
+  while (currStmt->getBlock() != this) {
+    currStmt = currStmt->getParentStmt();
+    if (!currStmt)
+      return nullptr;
+  }
+  return currStmt;
+}
@@ -392,7 +392,7 @@ void DmaStartOp::print(OpAsmPrinter *p) const {
 }
 
 // Parse DmaStartOp.
-// EX:
+// Ex:
 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
 //     %tag[%index] :
 //    memref<3 x vector<8x128xf32>, (d0) -> (d0), 0>,
@@ -458,33 +458,38 @@ bool DmaStartOp::parse(OpAsmParser *parser, OperationState *result) {
 // ---------------------------------------------------------------------------
 // DmaWaitOp
 // ---------------------------------------------------------------------------
-// Parse DmaWaitOp.
-// Eg:
-//   dma_wait %tag[%index] : memref<1 x i32, (d0) -> (d0), 4>
-//
 
 void DmaWaitOp::print(OpAsmPrinter *p) const {
   *p << getOperationName() << ' ';
   // Print operands.
   p->printOperand(getTagMemRef());
   *p << '[';
   p->printOperands(getTagIndices());
-  *p << ']';
+  *p << "], ";
+  p->printOperand(getNumElements());
   *p << " : " << *getTagMemRef()->getType();
 }
 
+// Parse DmaWaitOp.
+// Eg:
+//   dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4>
+//
 bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType tagMemrefInfo;
   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
   Type *type;
   auto *indexType = parser->getBuilder().getIndexType();
+  OpAsmParser::OperandType numElementsInfo;
 
-  // Parse tag memref and index.
+  // Parse tag memref, its indices, and dma size.
   if (parser->parseOperand(tagMemrefInfo) ||
       parser->parseOperandList(tagIndexInfos, -1,
                                OpAsmParser::Delimiter::Square) ||
+      parser->parseComma() || parser->parseOperand(numElementsInfo) ||
      parser->parseColonType(type) ||
       parser->resolveOperand(tagMemrefInfo, type, result->operands) ||
-      parser->resolveOperands(tagIndexInfos, indexType, result->operands))
+      parser->resolveOperands(tagIndexInfos, indexType, result->operands) ||
+      parser->resolveOperand(numElementsInfo, indexType, result->operands))
     return true;
 
   if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
@@ -181,57 +181,6 @@ generateLoop(AffineMap lb, AffineMap ub,
   return loopChunk;
 }
 
-// Returns delay of that child statement of 'forStmt' which either has 'operand'
-// as one of its operands or has a descendant statement with operand 'operand'.
-// This is a naive implementation. If performance becomes an issue, a map can
-// be used to store 'delays' - to look up the delay for a statement in constant
-// time.
-static uint64_t getContainingStmtDelay(const StmtOperand &operand,
-                                       const ForStmt &forStmt,
-                                       ArrayRef<uint64_t> delays) {
-  // Traverse up the statement hierarchy starting from the owner of operand to
-  // find the ancestor statement that resides in the block of 'forStmt'.
-  const Statement *stmt = operand.getOwner();
-  assert(stmt != nullptr);
-  while (stmt->getParentStmt() != &forStmt) {
-    stmt = stmt->getParentStmt();
-    assert(stmt && "traversing parent's should reach forStmt block");
-  }
-  // Look up the delay of 'stmt'.
-  unsigned j = 0;
-  for (const auto &s : forStmt) {
-    if (&s == stmt)
-      break;
-    j++;
-  }
-  assert(j < forStmt.getStatements().size() && "child stmt should be found");
-  return delays[j];
-}
-
-/// Checks if SSA dominance would be violated if a for stmt's body statements
-/// are shifted by the specified delays. This method checks if a 'def' and all
-/// its uses have the same delay factor.
-bool mlir::checkDominancePreservationOnShift(const ForStmt &forStmt,
-                                             ArrayRef<uint64_t> delays) {
-  assert(delays.size() == forStmt.getStatements().size());
-  unsigned s = 0;
-  for (const auto &stmt : forStmt) {
-    // A for or if stmt does not produce any def/results (that are used
-    // outside).
-    if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
-      for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
-        const MLValue *result = opStmt->getResult(i);
-        for (const StmtOperand &use : result->getUses()) {
-          if (delays[s] != getContainingStmtDelay(use, forStmt, delays))
-            return false;
-        }
-      }
-    }
-    s++;
-  }
-  return true;
-}
-
 /// Skew the statements in the body of a 'for' statement with the specified
 /// statement-wise delays. The delays are with respect to the original execution
 /// order. A delay of zero for each statement will lead to no change.
@@ -260,7 +209,7 @@ UtilResult mlir::stmtBodySkew(ForStmt *forStmt, ArrayRef<uint64_t> delays,
     return UtilResult::Failure;
   uint64_t tripCount = mayBeConstTripCount.getValue();
 
-  assert(checkDominancePreservationOnShift(*forStmt, delays) &&
+  assert(isStmtwiseShiftValid(*forStmt, delays) &&
          "dominance preservation failed\n");
 
   unsigned numChildStmts = forStmt->getStatements().size();
@@ -22,21 +22,31 @@
 #include "mlir/Transforms/Passes.h"
 
 #include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/StmtVisitor.h"
 #include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Transforms/LoopUtils.h"
 #include "mlir/Transforms/Pass.h"
 #include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "pipeline-data-transfer"
 
 using namespace mlir;
 
 namespace {
 
-struct PipelineDataTransfer : public MLFunctionPass {
-  explicit PipelineDataTransfer() {}
+struct PipelineDataTransfer : public MLFunctionPass,
+                              StmtWalker<PipelineDataTransfer> {
   PassResult runOnMLFunction(MLFunction *f) override;
+  PassResult runOnForStmt(ForStmt *forStmt);
+
+  // Collect all 'for' statements.
+  void visitForStmt(ForStmt *forStmt) { forStmts.push_back(forStmt); }
+  std::vector<ForStmt *> forStmts;
 };
 
 } // end anonymous namespace
@@ -47,20 +57,6 @@ MLFunctionPass *mlir::createPipelineDataTransferPass() {
   return new PipelineDataTransfer();
 }
 
-/// Given a DMA start operation, returns the operand position of either the
-/// source or destination memref depending on the one that is at the higher
-/// level of the memory hierarchy.
-// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
-// added. TODO(b/117228571)
-static unsigned getHigherMemRefPos(OpPointer<DmaStartOp> dmaStartOp) {
-  unsigned srcDmaPos = 0;
-  unsigned destDmaPos = dmaStartOp->getSrcMemRefRank() + 1;
-
-  if (dmaStartOp->getSrcMemorySpace() > dmaStartOp->getDstMemorySpace())
-    return srcDmaPos;
-  return destDmaPos;
-}
-
 // Returns the position of the tag memref operand given a DMA statement.
 // Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
 // added. TODO(b/117228571)
@@ -76,18 +72,20 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
 
 /// Doubles the buffer of the supplied memref while replacing all uses of the
 /// old memref. Returns false if such a replacement cannot be performed.
-static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
+static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
   MLFuncBuilder bInner(forStmt, forStmt->begin());
+  bInner.setInsertionPoint(forStmt, forStmt->begin());
 
   // Doubles the shape with a leading dimension extent of 2.
-  auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
+  auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
     // Add the leading dimension in the shape for the double buffer.
-    ArrayRef<int> shape = origMemRefType->getShape();
+    ArrayRef<int> shape = oldMemRefType->getShape();
     SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
     shapeSizes.insert(shapeSizes.begin(), 2);
 
-    auto *newMemRefType = bInner.getMemRefType(shapeSizes, bInner.getF32Type());
+    auto *newMemRefType =
+        bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
+                             oldMemRefType->getMemorySpace());
    return newMemRefType;
   };
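A self-contained sketch of the shape arithmetic in the doubleShape lambda above (plain C++ on plain vectors, not MemRefType): a leading dimension of extent 2 is prepended, so e.g. memref<32xf32> becomes memref<2x32xf32>, and iteration i addresses half i mod 2 while the DMA fills the other half.

#include <vector>

// Toy mirror of the doubleShape lambda: prepend a leading extent of 2.
std::vector<int> doubleShape(const std::vector<int> &shape) {
  std::vector<int> doubled(shape.begin(), shape.end());
  doubled.insert(doubled.begin(), 2); // {32} -> {2, 32}
  return doubled;
}

The real doubleBuffer additionally carries over the element type and memory space of the old memref, which is one of the bug fixes in this commit.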
@@ -105,113 +103,187 @@ static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
   auto ivModTwoOp =
       bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
   if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
-                                cast<MLValue>(ivModTwoOp->getResult(0))))
+                                cast<MLValue>(ivModTwoOp->getResult(0)))) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "memref replacement for double buffering failed\n";);
+    cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
     return false;
+  }
   return true;
 }
 
-// For testing purposes, this just runs on the first 'for' statement of an
-// MLFunction at the top level.
-// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
-// the other TODOs listed inside are dealt with.
+/// Returns false if this succeeds on at least one 'for' stmt.
 PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
-  if (f->empty())
-    return PassResult::Success;
-
-  ForStmt *forStmt = nullptr;
-  for (auto &stmt : *f) {
-    if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
-      break;
-    }
-  }
-  if (!forStmt)
-    return PassResult::Success;
+  // Do a post order walk so that inner loop DMAs are processed first. This is
+  // necessary since 'for' statements nested within would otherwise become
+  // invalid (erased) when the outer loop is pipelined (the pipelined one gets
+  // deleted and replaced by a prologue, a new steady-state loop and an
+  // epilogue).
+  forStmts.clear();
+  walkPostOrder(f);
+  bool ret = true;
+  for (auto *forStmt : forStmts) {
+    ret = ret & runOnForStmt(forStmt);
+  }
+  return ret ? failure() : success();
+}
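To illustrate why the post-order walk matters, here is a self-contained toy (plain C++; Loop and all names are made up): inner loops are visited, and thus pipelined, before the outer loop that contains them is erased and rebuilt.

#include <iostream>
#include <vector>

struct Loop {
  const char *name;
  std::vector<Loop *> inner;
};

void walkPostOrder(Loop *l, std::vector<Loop *> &out) {
  for (Loop *i : l->inner)
    walkPostOrder(i, out);
  out.push_back(l); // parent visited only after all of its children
}

int main() {
  Loop i1{"i1", {}}, i0{"i0", {&i1}};
  std::vector<Loop *> order;
  walkPostOrder(&i0, order);
  for (Loop *l : order)
    std::cout << l->name << ' '; // prints: i1 i0
}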
-  unsigned numStmts = forStmt->getStatements().size();
-  if (numStmts == 0)
-    return PassResult::Success;
-
+// Check if tags of the dma start op and dma wait op match.
+static bool checkTagMatch(OpPointer<DmaStartOp> startOp,
+                          OpPointer<DmaWaitOp> waitOp) {
+  if (startOp->getTagMemRef() != waitOp->getTagMemRef())
+    return false;
+  auto startIndices = startOp->getTagIndices();
+  auto waitIndices = waitOp->getTagIndices();
+  // Both of these have the same number of indices since they correspond to the
+  // same tag memref.
+  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
+            e = startIndices.end();
+       it != e; ++it, ++wIt) {
+    // Keep it simple for now, just checking if indices match.
+    // TODO(mlir-team): this would in general need to check if there is no
+    // intervening write writing to the same tag location, i.e., memory last
+    // write/data flow analysis. This is however sufficient/powerful enough for
+    // now since the DMA generation pass or the input for it will always have
+    // start/wait with matching tags (same SSA operand indices).
+    if (*it != *wIt)
+      return false;
+  }
+  return true;
+}
-  SmallVector<OperationStmt *, 4> dmaStartStmts;
-  SmallVector<OperationStmt *, 4> dmaFinishStmts;
+// Identify matching DMA start/finish statements to overlap computation with.
+static void findMatchingStartFinishStmts(
+    ForStmt *forStmt,
+    SmallVectorImpl<std::pair<OperationStmt *, OperationStmt *>>
+        &startWaitPairs) {
+  SmallVector<OperationStmt *, 4> dmaStartStmts, dmaFinishStmts;
   for (auto &stmt : *forStmt) {
     auto *opStmt = dyn_cast<OperationStmt>(&stmt);
     if (!opStmt)
       continue;
-    if (opStmt->is<DmaStartOp>()) {
-      dmaStartStmts.push_back(opStmt);
-    } else if (opStmt->is<DmaWaitOp>()) {
+    // Collect DMA finish statements.
+    if (opStmt->is<DmaWaitOp>()) {
       dmaFinishStmts.push_back(opStmt);
+      continue;
     }
+    OpPointer<DmaStartOp> dmaStartOp;
+    if (!(dmaStartOp = opStmt->getAs<DmaStartOp>()))
+      continue;
+    // Only DMAs incoming into higher memory spaces.
+    // TODO(bondhugula): outgoing DMAs.
+    if (!dmaStartOp->isDestMemorySpaceFaster())
+      continue;
+
+    // We only double buffer if the buffer is not live out of loop.
+    const MLValue *memref =
+        cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
+    bool escapingUses = false;
+    for (const auto &use : memref->getUses()) {
+      if (!dominates(*forStmt, *use.getOwner())) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "can't pipeline: buffer is live out of loop\n";);
+        escapingUses = true;
+        break;
+      }
+    }
+    if (!escapingUses)
+      dmaStartStmts.push_back(opStmt);
   }
 
-  // TODO(bondhugula,andydavis): match tag memref's (requires memory-based
-  // subscript check utilities). Assume for now that start/finish are matched in
-  // the order they appear.
-  if (dmaStartStmts.size() != dmaFinishStmts.size())
+  // For each start statement, we look for a matching finish statement.
+  for (auto *dmaStartStmt : dmaStartStmts) {
+    for (auto *dmaFinishStmt : dmaFinishStmts) {
+      if (checkTagMatch(dmaStartStmt->getAs<DmaStartOp>(),
+                        dmaFinishStmt->getAs<DmaWaitOp>())) {
+        startWaitPairs.push_back({dmaStartStmt, dmaFinishStmt});
+        break;
+      }
+    }
+  }
+}
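A toy mirror of the pairing loop above (plain C++; Tag and all names are made up, with tag operands reduced to comparable value ids): each start is matched to the first finish whose tag operands compare equal.

#include <utility>
#include <vector>

using Tag = std::pair<int, int>; // (tag memref id, tag index value id)

std::vector<std::pair<int, int>>
matchPairs(const std::vector<Tag> &starts, const std::vector<Tag> &finishes) {
  std::vector<std::pair<int, int>> pairs; // (start idx, finish idx)
  for (size_t s = 0; s < starts.size(); ++s)
    for (size_t f = 0; f < finishes.size(); ++f)
      if (starts[s] == finishes[f]) { // same SSA operands, as in checkTagMatch
        pairs.push_back({(int)s, (int)f});
        break;
      }
  return pairs;
}

Matching on SSA operand equality suffices here because, per the TODO in checkTagMatch, the DMA generation pass always emits start/wait pairs with identical tag operands.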
+/// Overlap DMA transfers with computation in this loop. If successful,
+/// 'forStmt' is deleted, and a prologue, a new pipelined loop, and epilogue are
+/// inserted right before where it was.
+PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
+  auto mayBeConstTripCount = getConstantTripCount(*forStmt);
+  if (!mayBeConstTripCount.hasValue()) {
+    LLVM_DEBUG(llvm::dbgs() << "unknown trip count loop\n");
+    return PassResult::Failure;
+  }
+
+  SmallVector<std::pair<OperationStmt *, OperationStmt *>, 4> startWaitPairs;
+  findMatchingStartFinishStmts(forStmt, startWaitPairs);
+
+  if (startWaitPairs.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "No dma start/finish pairs\n";);
+    return failure();
+  }
 
   // Double the buffers for the higher memory space memref's.
-  // TODO(bondhugula): assuming we don't have multiple DMA starts for the same
-  // memref.
-  // Identify memref's to replace by scanning through all DMA start statements.
-  // A DMA start statement has two memref's - the one from the higher level of
-  // memory hierarchy is the one to double buffer.
   // TODO(bondhugula): check whether double-buffering is even necessary.
   // TODO(bondhugula): make this work with different layouts: assuming here that
   // the dimension we are adding here for the double buffering is the outermost
   // dimension.
+  // Identify memref's to replace by scanning through all DMA start statements.
+  // A DMA start statement has two memref's - the one from the higher level of
+  // memory hierarchy is the one to double buffer.
-  for (auto *dmaStartStmt : dmaStartStmts) {
-    MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
-        getHigherMemRefPos(dmaStartStmt->getAs<DmaStartOp>())));
+  for (auto &pair : startWaitPairs) {
+    auto *dmaStartStmt = pair.first;
+    const MLValue *oldMemRef = cast<MLValue>(dmaStartStmt->getOperand(
+        dmaStartStmt->getAs<DmaStartOp>()->getFasterMemPos()));
     if (!doubleBuffer(oldMemRef, forStmt)) {
-      return PassResult::Failure;
+      // Normally, double buffering should not fail because we already checked
+      // that there are no uses outside.
+      LLVM_DEBUG(llvm::dbgs() << "double buffering failed for: \n";);
+      LLVM_DEBUG(dmaStartStmt->dump());
+      return failure();
     }
   }
 
-  // Double the buffers for tag memref's.
-  for (auto *dmaFinishStmt : dmaFinishStmts) {
-    MLValue *oldTagMemRef = cast<MLValue>(
+  // Double the buffers for tag memrefs.
+  for (auto &pair : startWaitPairs) {
+    const auto *dmaFinishStmt = pair.second;
+    const MLValue *oldTagMemRef = cast<MLValue>(
         dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
     if (!doubleBuffer(oldTagMemRef, forStmt)) {
-      return PassResult::Failure;
+      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
+      return failure();
     }
   }
-  // Collect all compute ops.
-  std::vector<const Statement *> computeOps;
-  computeOps.reserve(forStmt->getStatements().size());
+  // Double buffering would have invalidated all the old DMA start/wait stmts.
+  startWaitPairs.clear();
+  findMatchingStartFinishStmts(forStmt, startWaitPairs);
 
   // Store delay for statement for later lookup for AffineApplyOp's.
-  DenseMap<const Statement *, unsigned> opDelayMap;
-  for (auto &stmt : *forStmt) {
-    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
-    if (!opStmt) {
-      // All for and if stmt's are treated as pure compute operations.
-      opDelayMap[&stmt] = 1;
-    } else if (opStmt->is<DmaStartOp>()) {
-      // DMA starts are not shifted.
-      opDelayMap[opStmt] = 0;
-      // Set shifts for DMA start stmt's affine operand computation slices to 0.
-      if (auto *slice = mlir::createAffineComputationSlice(opStmt)) {
-        opDelayMap[slice] = 0;
-      } else {
-        // If a slice wasn't created, the reachable affine_apply op's from its
-        // operands are the ones that go with it.
-        SmallVector<OperationStmt *, 4> affineApplyStmts;
-        SmallVector<MLValue *, 4> operands(opStmt->getOperands());
-        getReachableAffineApplyOps(operands, affineApplyStmts);
-        for (auto *op : affineApplyStmts) {
-          opDelayMap[op] = 0;
-        }
-      }
-    } else if (opStmt->is<DmaWaitOp>()) {
-      // DMA finish op shifted by one.
-      opDelayMap[opStmt] = 1;
-    } else {
-      // Everything else is a compute op; so shifted by one (op's supplying
-      // 'affine' operands to DMA start's have already been set right shifts.
-      opDelayMap[opStmt] = 1;
-      computeOps.push_back(&stmt);
-    }
-  }
+  DenseMap<const Statement *, unsigned> stmtDelayMap;
+  for (auto &pair : startWaitPairs) {
+    auto *dmaStartStmt = pair.first;
+    assert(dmaStartStmt->is<DmaStartOp>());
+    stmtDelayMap[dmaStartStmt] = 0;
+    // Set shifts for DMA start stmt's affine operand computation slices to 0.
+    if (auto *slice = mlir::createAffineComputationSlice(dmaStartStmt)) {
+      stmtDelayMap[slice] = 0;
+    } else {
+      // If a slice wasn't created, the reachable affine_apply op's from its
+      // operands are the ones that go with it.
+      SmallVector<OperationStmt *, 4> affineApplyStmts;
+      SmallVector<MLValue *, 4> operands(dmaStartStmt->getOperands());
+      getReachableAffineApplyOps(operands, affineApplyStmts);
+      for (const auto *stmt : affineApplyStmts) {
+        stmtDelayMap[stmt] = 0;
+      }
+    }
+  }
+  // Everything else (including compute ops and dma finish) are shifted by one.
+  for (const auto &stmt : *forStmt) {
+    if (stmtDelayMap.find(&stmt) == stmtDelayMap.end()) {
+      stmtDelayMap[&stmt] = 1;
+    }
+  }
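A toy mirror of the shift assignment above (plain C++; Kind and all names are made up): DMA starts (and, in the real pass, the affine ops feeding them) keep shift 0, and every remaining statement is shifted by 1, so iteration i's compute overlaps the DMA issued for iteration i+1.

#include <cstdint>
#include <vector>

enum class Kind { DmaStart, DmaWait, Compute };

std::vector<uint64_t> assignShifts(const std::vector<Kind> &body) {
  std::vector<uint64_t> shifts(body.size(), 1); // default: shifted by one
  for (size_t i = 0; i < body.size(); ++i)
    if (body[i] == Kind::DmaStart)
      shifts[i] = 0; // issue the next iteration's DMA in the current one
  return shifts;
}

With these shifts, stmtBodySkew produces the prologue / steady-state / epilogue structure that the updated test below checks.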
@@ -219,18 +291,20 @@ PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
   std::vector<uint64_t> delays(forStmt->getStatements().size());
   unsigned s = 0;
   for (const auto &stmt : *forStmt) {
-    assert(opDelayMap.find(&stmt) != opDelayMap.end());
-    delays[s++] = opDelayMap[&stmt];
+    assert(stmtDelayMap.find(&stmt) != stmtDelayMap.end());
+    delays[s++] = stmtDelayMap[&stmt];
   }
 
-  if (!checkDominancePreservationOnShift(*forStmt, delays)) {
+  if (!isStmtwiseShiftValid(*forStmt, delays)) {
     // Violates SSA dominance.
+    LLVM_DEBUG(llvm::dbgs() << "Dominance check failed\n";);
     return PassResult::Failure;
   }
 
   if (stmtBodySkew(forStmt, delays)) {
+    LLVM_DEBUG(llvm::dbgs() << "stmt body skewing failed\n";);
     return PassResult::Failure;
   }
 
-  return PassResult::Success;
+  return success();
 }
@@ -48,7 +48,8 @@ static bool isMemRefDereferencingOp(const Operation &op) {
 /// at the start for now.
 // TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
 // extended to add additional indices at any position.
-bool mlir::replaceAllMemRefUsesWith(MLValue *oldMemRef, MLValue *newMemRef,
+bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
+                                    MLValue *newMemRef,
                                     ArrayRef<MLValue *> extraIndices,
                                     AffineMap indexRemap) {
   unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
@@ -219,11 +220,11 @@ mlir::createComposedAffineApplyOp(MLFuncBuilder *builder, Location *loc,
 /// This allows applying different transformations on send and compute (for eg.
 /// different shifts/delays).
 ///
-/// Returns nullptr if none of the operands were the result of an affine_apply
-/// and thus there was no affine computation slice to create. Returns the newly
-/// affine_apply operation statement otherwise.
-///
+/// Returns nullptr either if none of opStmt's operands were the result of an
+/// affine_apply and thus there was no affine computation slice to create, or if
+/// all the affine_apply op's supplying operands to this opStmt do not have any
+/// uses besides this opStmt. Returns the new affine_apply operation statement
+/// otherwise.
 OperationStmt *mlir::createAffineComputationSlice(OperationStmt *opStmt) {
   // Collect all operands that are results of affine apply ops.
   SmallVector<MLValue *, 4> subOperands;
@@ -1,30 +1,28 @@
 // RUN: mlir-opt %s -pipeline-data-transfer | FileCheck %s
 
 // CHECK: #map0 = (d0) -> (d0 mod 2, d0 mod 2)
 // CHECK-NEXT: #map1 = (d0) -> (d0 - 1)
 // CHECK-NEXT: #map2 = (d0) -> (d0 mod 2)
-// CHECK-NEXT: mlfunc @loop_nest_dma() {
-// CHECK-NEXT:   %c8 = constant 8 : index
+// CHECK-LABEL: mlfunc @loop_nest_dma() {
+mlfunc @loop_nest_dma() {
+// CHECK:      %c8 = constant 8 : index
 // CHECK-NEXT: %c0 = constant 0 : index
 // CHECK-NEXT: %0 = alloc() : memref<2x1xf32>
-// CHECK-NEXT: %1 = alloc() : memref<2x32xf32>
+// CHECK-NEXT: %1 = alloc() : memref<2x32xf32, 1>
 // CHECK-NEXT: %2 = alloc() : memref<256xf32, (d0) -> (d0)>
 // CHECK-NEXT: %3 = alloc() : memref<32xf32, (d0) -> (d0), 1>
 // CHECK-NEXT: %4 = alloc() : memref<1xf32>
 // CHECK-NEXT: %c0_0 = constant 0 : index
 // CHECK-NEXT: %c128 = constant 128 : index
 // CHECK-NEXT: %5 = affine_apply #map0(%c0)
-// CHECK-NEXT: dma_start %2[%c0], %1[%5#0, %c0], %c128, %0[%5#1, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32>, memref<2x1xf32>
+// CHECK-NEXT: dma_start %2[%c0], %1[%5#0, %c0], %c128, %0[%5#1, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32, 1>, memref<2x1xf32>
 // CHECK-NEXT: for %i0 = 1 to 7 {
 // CHECK-NEXT:   %6 = affine_apply #map0(%i0)
-// CHECK-NEXT:   dma_start %2[%i0], %1[%6#0, %i0], %c128, %0[%6#1, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32>, memref<2x1xf32>
+// CHECK-NEXT:   dma_start %2[%i0], %1[%6#0, %i0], %c128, %0[%6#1, %c0_0] : memref<256xf32, (d0) -> (d0)>, memref<2x32xf32, 1>, memref<2x1xf32>
 // CHECK-NEXT:   %7 = affine_apply #map1(%i0)
 // CHECK-NEXT:   %8 = affine_apply #map2(%7)
 // CHECK-NEXT:   %9 = affine_apply #map2(%7)
-// CHECK-NEXT:   dma_wait %0[%8, %c0_0] : memref<2x1xf32>
-// CHECK-NEXT:   %10 = load %1[%9, %7] : memref<2x32xf32>
+// CHECK-NEXT:   dma_wait %0[%8, %c0_0], %c128 : memref<2x1xf32>
+// CHECK-NEXT:   %10 = load %1[%9, %7] : memref<2x32xf32, 1>
 // CHECK-NEXT:   %11 = "compute"(%10) : (f32) -> f32
-// CHECK-NEXT:   store %11, %1[%9, %7] : memref<2x32xf32>
+// CHECK-NEXT:   store %11, %1[%9, %7] : memref<2x32xf32, 1>
 // CHECK-NEXT:   for %i1 = 0 to 127 {
 // CHECK-NEXT:     "do_more_compute"(%7, %i1) : (index, index) -> ()
 // CHECK-NEXT:   }
@@ -32,15 +30,14 @@
 // CHECK-NEXT:   %12 = affine_apply #map1(%c8)
 // CHECK-NEXT:   %13 = affine_apply #map2(%12)
 // CHECK-NEXT:   %14 = affine_apply #map2(%12)
-// CHECK-NEXT:   dma_wait %0[%13, %c0_0] : memref<2x1xf32>
-// CHECK-NEXT:   %15 = load %1[%14, %12] : memref<2x32xf32>
+// CHECK-NEXT:   dma_wait %0[%13, %c0_0], %c128 : memref<2x1xf32>
+// CHECK-NEXT:   %15 = load %1[%14, %12] : memref<2x32xf32, 1>
 // CHECK-NEXT:   %16 = "compute"(%15) : (f32) -> f32
-// CHECK-NEXT:   store %16, %1[%14, %12] : memref<2x32xf32>
+// CHECK-NEXT:   store %16, %1[%14, %12] : memref<2x32xf32, 1>
 // CHECK-NEXT:   for %i2 = 0 to 127 {
 // CHECK-NEXT:     "do_more_compute"(%12, %i2) : (index, index) -> ()
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return
-mlfunc @loop_nest_dma() {
-
   %A = alloc() : memref<256 x f32, (d0) -> (d0), 0>
   %Ah = alloc() : memref<32 x f32, (d0) -> (d0), 1>
@@ -48,11 +45,11 @@ mlfunc @loop_nest_dma() {
   %tag = alloc() : memref<1 x f32>
 
   %zero = constant 0 : index
-  %size = constant 128 : index
+  %num_elts = constant 128 : index
 
   for %i = 0 to 7 {
-    dma_start %A[%i], %Ah[%i], %size, %tag[%zero] : memref<256 x f32, (d0)->(d0), 0>, memref<32 x f32, (d0)->(d0), 1>, memref<1 x f32>
-    dma_wait %tag[%zero] : memref<1 x f32>
+    dma_start %A[%i], %Ah[%i], %num_elts, %tag[%zero] : memref<256 x f32, (d0)->(d0), 0>, memref<32 x f32, (d0)->(d0), 1>, memref<1 x f32>
+    dma_wait %tag[%zero], %num_elts : memref<1 x f32>
     %v = load %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
     %r = "compute"(%v) : (f32) -> (f32)
     store %r, %Ah[%i] : memref<32 x f32, (d0) -> (d0), 1>
@@ -62,3 +59,76 @@ mlfunc @loop_nest_dma() {
   }
   return
 }
+
+#map0 = (d0, d1) -> (d0, d1)
+#map1 = (d0, d1) -> ((d0 * 2048 + d1 * 256) floordiv 32, 0)
+#map2 = (d0) -> ((d0 * 2048) floordiv 32, 0)
+// CHECK: mlfunc @loop_dma_nested(%arg0 : memref<512x32xvector<8xf32>
+mlfunc @loop_dma_nested(%arg0 : memref<512x32xvector<8xf32>, #map0>, %arg1 : memref<512x32xvector<8xf32>, #map0>, %arg2 : memref<512x32xvector<8xf32>, #map0>) {
+  %num_elts = constant 256 : index
+  %c0 = constant 0 : index
+  %0 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
+  %1 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
+  %2 = alloc() : memref<64x4xvector<8xf32>, #map0, 2>
+  %3 = alloc() : memref<2xi32>
+  %4 = alloc() : memref<2xi32>
+  %5 = alloc() : memref<2xi32>
+  // Prologue for DMA overlap on arg2.
+  // CHECK: dma_start %arg2[
+  // CHECK-NEXT: for %i0 = 1 to 7 {
+  for %i0 = 0 to 7 {
+    %6 = affine_apply #map2(%i0)
+    dma_start %arg2[%6#0, %6#1], %2[%c0, %c0], %num_elts, %5[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
+    dma_wait %5[%c0], %num_elts : memref<2xi32>
+    // Steady state for DMA overlap on arg2
+    // CHECK: dma_start %arg2[
+    // CHECK: dma_wait %0[
+    // Prologue for DMA overlap on arg0, arg1 nested within i0
+    // CHECK: dma_start %arg0[
+    // CHECK: dma_start %arg1[
+    // CHECK-NEXT: for %i1 = 1 to 7 {
+    for %i1 = 0 to 7 {
+      %7 = affine_apply #map1(%i0, %i1)
+      %8 = affine_apply #map2(%i1)
+      dma_start %arg0[%7#0, %7#1], %0[%c0, %c0], %num_elts, %3[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
+      dma_start %arg1[%8#0, %8#1], %1[%c0, %c0], %num_elts, %4[%c0] : memref<512x32xvector<8xf32>, #map0>, memref<64x4xvector<8xf32>, #map0, 2>, memref<2xi32>
+      dma_wait %3[%c0], %num_elts : memref<2xi32>
+      dma_wait %4[%c0], %num_elts : memref<2xi32>
+      // Steady state for DMA overlap on arg0, arg1
+      // CHECK: dma_start %arg0[
+      // CHECK: dma_start %arg1[
+      // CHECK: dma_wait %3[
+      // CHECK: dma_wait %2[
+      // CHECK-NEXT: for %i2 = 0 to 3 {
+      for %i2 = 0 to 3 {
+        "foo"() : () -> ()
+      }
+    }
+    // epilogue for arg0, arg1
+    // CHECK: dma_wait %3[
+    // CHECK: dma_wait %2[
+
+    // epilogue for DMA overlap on %arg2
+    // CHECK: dma_wait %0[%37, %c0_2], %c256 : memref<2x2xi32>
+    // Within the epilogue for arg2's DMA, we have the DMAs on %arg1, %arg2 nested.
+    // CHECK: dma_start %arg0[
+    // CHECK: dma_start %arg1[
+    // CHECK: for %i4 = 1 to 7 {
+    // CHECK: dma_start %arg0[
+    // CHECK: dma_start %arg1[
+    // CHECK: dma_wait %3[
+    // CHECK: dma_wait %2[
+    // CHECK: for %i5 = 0 to 3 {
+    // CHECK: "foo"() : () -> ()
+    // CHECK: dma_wait %3[
+    // CHECK: dma_wait %2[
+    // CHECK: for %i6 = 0 to 3 {
+
+    // The DMAs below are outgoing DMAs on arg2, not yet overlapped.
+    // CHECK: dma_start %1{{.*}}, %arg2[
+    // CHECK-NEXT: dma_wait %0[
+    dma_start %2[%c0, %c0], %arg2[%6#0, %6#1], %num_elts, %5[%c0] : memref<64x4xvector<8xf32>, #map0, 2>, memref<512x32xvector<8xf32>, #map0>, memref<2xi32>
+    dma_wait %5[%c0], %num_elts : memref<2xi32>
+  } // CHECK: }
+  return
+}