llvm/mlir/lib/Transforms/PipelineDataTransfer.cpp
Nicolas Vasilache ce2edea135 [MLIR] Cleanup AffineExpr
This CL introduces a series of cleanups for AffineExpr value types:
1. to make it clear that the value types should be used, the pointer AffineExpr types are put in the detail namespace. Unfortunately, since the value type operator-> only forwards to the underlying pointer type, we still need to expose this in the include file for now;
2. AffineExprKind is OK to use, so it comes out of detail and out of AffineExpr;
3. getAffineDimExpr, getAffineSymbolExpr, and getAffineConstantExpr are similarly extracted as free functions, and their naming is made consistent across Builder, MLIRContext, and AffineExpr;
4. the AffineBinaryOpExpr::simplify functions are made into static free functions; in particular, they are moved away from AffineMap.cpp, where they do not belong;
5. operator AffineExprType is made explicit;
6. uses the binary operators everywhere possible;
7. drops pointer usage everywhere outside of AffineExpr.cpp, MLIRContext.cpp, and AsmPrinter.cpp.
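
A rough sketch of the value-type usage this enables (illustrative only: ctx stands in for an MLIRContext pointer, and exact signatures may differ at this revision):

    AffineExpr d0 = getAffineDimExpr(0, ctx);    // free function, not a method
    AffineExpr s0 = getAffineSymbolExpr(0, ctx);
    AffineExpr e = d0 % 2 + s0 * 4;              // overloaded binary operators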

PiperOrigin-RevId: 216207212
2019-03-29 13:24:45 -07:00

//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass to pipeline data transfers.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Passes.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseMap.h"
using namespace mlir;
namespace {
struct PipelineDataTransfer : public MLFunctionPass {
explicit PipelineDataTransfer() {}
PassResult runOnMLFunction(MLFunction *f) override;
};
} // end anonymous namespace

/// Creates a pass to pipeline explicit movement of data across levels of the
/// memory hierarchy.
MLFunctionPass *mlir::createPipelineDataTransferPass() {
  return new PipelineDataTransfer();
}

// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's or
// op traits for it are added. TODO(b/117228571)
static bool isDmaStartStmt(const OperationStmt &stmt) {
  return stmt.getName().strref().contains("dma.in.start") ||
         stmt.getName().strref().contains("dma.out.start");
}

// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static bool isDmaFinishStmt(const OperationStmt &stmt) {
  return stmt.getName().strref().contains("dma.finish");
}

/// Given a DMA start operation, returns the operand position of either the
/// source or destination memref depending on the one that is at the higher
/// level of the memory hierarchy.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static unsigned getHigherMemRefPos(const OperationStmt &dmaStartStmt) {
  assert(isDmaStartStmt(dmaStartStmt));
  unsigned srcDmaPos = 0;
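  // The destination memref operand follows the source memref and its index
  // operands (one index operand per dimension of the source memref).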
  unsigned destDmaPos =
      cast<MemRefType>(dmaStartStmt.getOperand(0)->getType())->getRank() + 1;
  if (cast<MemRefType>(dmaStartStmt.getOperand(srcDmaPos)->getType())
          ->getMemorySpace() >
      cast<MemRefType>(dmaStartStmt.getOperand(destDmaPos)->getType())
          ->getMemorySpace())
    return srcDmaPos;
  return destDmaPos;
}

// Returns the position of the tag memref operand given a DMA statement.
// Temporary utility: will be replaced when DmaStart/DmaFinish abstract op's are
// added. TODO(b/117228571)
static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
  assert(isDmaStartStmt(dmaStmt) || isDmaFinishStmt(dmaStmt));
  if (isDmaStartStmt(dmaStmt)) {
    // Second to last operand.
    return dmaStmt.getNumOperands() - 2;
  }
  // First operand for a dma finish statement.
  return 0;
}

/// Doubles the buffer of the supplied memref.
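/// Creates a new memref with a leading dimension of extent two and replaces
/// all uses of 'oldMemRef' with it, indexing the new leading dimension with
/// the loop IV modulo 2 so that successive iterations of 'forStmt' alternate
/// between the two buffers.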
static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
  MLFuncBuilder bInner(forStmt, forStmt->begin());
  bInner.setInsertionPoint(forStmt, forStmt->begin());

  // Doubles the shape with a leading dimension extent of 2.
  auto doubleShape = [&](MemRefType *origMemRefType) -> MemRefType * {
    // Add the leading dimension in the shape for the double buffer.
    ArrayRef<int> shape = origMemRefType->getShape();
    SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
    shapeSizes.insert(shapeSizes.begin(), 2);
    auto *newMemRefType =
        bInner.getMemRefType(shapeSizes, bInner.getF32Type());
    return newMemRefType;
  };
  auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));

  // Create and place the alloc at the top level.
  auto *func = forStmt->getFunction();
  MLFuncBuilder topBuilder(func, func->begin());
  auto *newMemRef = cast<MLValue>(
      topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
          ->getResult());

  auto d0 = bInner.getAffineDimExpr(0);
  auto *modTwoMap = bInner.getAffineMap(1, 0, {d0 % 2}, {});
  auto ivModTwoOp =
      bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
  if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
                                ivModTwoOp->getResult(0)))
    return false;
  // We don't need ivModTwoOp any more - this is cloned by
  // replaceAllMemRefUsesWith wherever the memref replacement happens. Once
  // b/117159533 is addressed, we'll eventually only need to pass
  // ivModTwoOp->getResult(0) to replaceAllMemRefUsesWith.
  cast<OperationStmt>(ivModTwoOp->getOperation())->eraseFromBlock();
  return true;
}

// For testing purposes, this just runs on the first for statement of an
// MLFunction at the top level.
// TODO(bondhugula): upgrade this to scan all the relevant 'for' statements when
// the other TODOs listed inside are dealt with.
PassResult PipelineDataTransfer::runOnMLFunction(MLFunction *f) {
  if (f->empty())
    return PassResult::Success;

  ForStmt *forStmt = nullptr;
  for (auto &stmt : *f) {
    if ((forStmt = dyn_cast<ForStmt>(&stmt))) {
      break;
    }
  }
  if (!forStmt)
    return PassResult::Success;

  unsigned numStmts = forStmt->getStatements().size();
  if (numStmts == 0)
    return PassResult::Success;

  SmallVector<OperationStmt *, 4> dmaStartStmts;
  SmallVector<OperationStmt *, 4> dmaFinishStmts;
  for (auto &stmt : *forStmt) {
    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
    if (!opStmt)
      continue;
    if (isDmaStartStmt(*opStmt)) {
      dmaStartStmts.push_back(opStmt);
    } else if (isDmaFinishStmt(*opStmt)) {
      dmaFinishStmts.push_back(opStmt);
    }
  }

  // TODO(bondhugula,andydavis): match tag memref's (requires memory-based
  // subscript check utilities). Assume for now that start/finish are matched in
  // the order they appear.
  if (dmaStartStmts.size() != dmaFinishStmts.size())
    return PassResult::Failure;

  // Double the buffers for the higher memory space memref's.
  // TODO(bondhugula): assuming we don't have multiple DMA starts for the same
  // memref.
  // TODO(bondhugula): check whether double-buffering is even necessary.
  // TODO(bondhugula): make this work with different layouts: assuming here that
  // the dimension we are adding here for the double buffering is the outermost
  // dimension.
  // Identify memref's to replace by scanning through all DMA start statements.
  // A DMA start statement has two memref's - the one from the higher level of
  // memory hierarchy is the one to double buffer.
  for (auto *dmaStartStmt : dmaStartStmts) {
    MLValue *oldMemRef = cast<MLValue>(
        dmaStartStmt->getOperand(getHigherMemRefPos(*dmaStartStmt)));
    if (!doubleBuffer(oldMemRef, forStmt))
      return PassResult::Failure;
  }

  // Double the buffers for tag memref's.
  for (auto *dmaFinishStmt : dmaFinishStmts) {
    MLValue *oldTagMemRef = cast<MLValue>(
        dmaFinishStmt->getOperand(getTagMemRefPos(*dmaFinishStmt)));
    if (!doubleBuffer(oldTagMemRef, forStmt))
      return PassResult::Failure;
  }
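
  // Pipelining scheme: DMA start statements are left at shift zero while all
  // other statements in the loop body (DMA finishes and compute) are shifted
  // by one iteration, so that the data movement for an iteration overlaps with
  // the computation of the previous one.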
  // Collect all compute ops.
  std::vector<const Statement *> computeOps;
  computeOps.reserve(forStmt->getStatements().size());
  // Store the shift (delay) for each statement for later lookup when handling
  // AffineApplyOp's.
  DenseMap<const Statement *, unsigned> opDelayMap;
  for (const auto &stmt : *forStmt) {
    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
    if (!opStmt) {
      // All for and if stmt's are treated as pure compute operations.
      // TODO(bondhugula): check whether such statements do not have any DMAs
      // nested within.
      opDelayMap[&stmt] = 1;
    } else if (isDmaStartStmt(*opStmt)) {
      // DMA starts are not shifted.
      opDelayMap[&stmt] = 0;
    } else if (isDmaFinishStmt(*opStmt)) {
      // DMA finish op shifted by one.
      opDelayMap[&stmt] = 1;
    } else if (!opStmt->is<AffineApplyOp>()) {
      // Compute op shifted by one.
      opDelayMap[&stmt] = 1;
      computeOps.push_back(&stmt);
    }
    // Shifts for affine apply op's determined later.
  }

  // Get the ancestor of a 'stmt' that lies in forStmt's block.
  auto getAncestorInForBlock =
      [&](const Statement *stmt, const StmtBlock &block) -> const Statement * {
    // Traverse up the statement hierarchy starting from the owner of operand to
    // find the ancestor statement that resides in the block of 'forStmt'.
    while (stmt != nullptr && stmt->getBlock() != &block) {
      stmt = stmt->getParentStmt();
    }
    return stmt;
  };

  // Determine delays for affine apply op's: look up the delay from the
  // consumer op. This code will be thrown away once we have a way to obtain
  // indices through a composed affine_apply op. See TODO(b/117159533). Such a
  // composed affine_apply will be used exclusively by a given memref
  // dereferencing op.
  for (const auto &stmt : *forStmt) {
    auto *opStmt = dyn_cast<OperationStmt>(&stmt);
    // Skip statements that aren't affine apply ops.
    if (!opStmt || !opStmt->is<AffineApplyOp>())
      continue;
    // Traverse uses of each result of the affine apply op.
    for (auto *res : opStmt->getResults()) {
      for (auto &use : res->getUses()) {
        auto *ancestorInForBlock =
            getAncestorInForBlock(use.getOwner(), *forStmt);
        assert(ancestorInForBlock &&
               "traversing parent should reach forStmt block");
        auto *opCheck = dyn_cast<OperationStmt>(ancestorInForBlock);
        if (!opCheck || opCheck->is<AffineApplyOp>())
          continue;
        assert(opDelayMap.find(ancestorInForBlock) != opDelayMap.end());
        if (opDelayMap.find(&stmt) != opDelayMap.end()) {
          // This is where we enforce that all uses of this affine_apply have
          // the same shift - so that we know what shift to use for the
          // affine_apply to preserve semantics.
          assert(opDelayMap[&stmt] == opDelayMap[ancestorInForBlock]);
        } else {
          // Obtain the delay from the consumer.
          opDelayMap[&stmt] = opDelayMap[ancestorInForBlock];
        }
      }
    }
  }

  // Get delays stored in map.
  std::vector<uint64_t> delays(forStmt->getStatements().size());
  unsigned s = 0;
  for (const auto &stmt : *forStmt) {
    delays[s++] = opDelayMap[&stmt];
  }

  if (!checkDominancePreservationOnShift(*forStmt, delays)) {
    // Violates SSA dominance.
    return PassResult::Failure;
  }
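
  // Skew the loop body: shift each statement by the delay computed for it
  // above.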
  if (stmtBodySkew(forStmt, delays))
    return PassResult::Failure;

  return PassResult::Success;
}