[mlir][vector][bufferize] Bufferize vector.mask and vector.yield

The masked op currently cannot bufferize out-of-place. Such IR would be rejected by One-Shot Bufferize because it would mean that a new buffer allocation is yielded from a block. Furthermore, only a single operation is currently allowed inside `vector.mask`.
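For example (a minimal sketch mirroring the one-shot-bufferize test added below), a masked transfer_write on a tensor

    %0 = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>

now bufferizes in-place to a masked transfer_write on the corresponding memref, without an intermediate allocation or copy:

    vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32, strided<[?], offset: ?>> } : vector<16xi1>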

Differential Revision: https://reviews.llvm.org/D141686
Author: Matthias Springer
Date: 2023-01-31 08:56:03 +01:00
parent c1fa8179d4
commit 199f368e35
5 changed files with 188 additions and 0 deletions

@@ -496,6 +496,15 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                                        cast<ToMemrefOp>(op));
  }

  // Remove all dead to_tensor ops.
  op->walk<WalkOrder::PostOrder>([&](ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty()) {
      rewriter.eraseOp(toTensorOp);
      return WalkResult::skip();
    }
    return WalkResult::advance();
  });

  /// Check the result of bufferization. Return an error if an op was not
  /// bufferized, unless partial bufferization is allowed.
  if (options.allowUnknownOps)

@@ -9,6 +9,7 @@
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Dialect.h"
@@ -131,6 +132,158 @@ struct GatherOpInterface
  }
};

/// Bufferization of vector.mask. Replaced with a new vector.mask that
/// operates on a memref.
struct MaskOpInterface
    : public BufferizableOpInterface::ExternalModel<MaskOpInterface,
                                                    vector::MaskOp> {
  SmallVector<OpOperand *>
  getAliasingOpOperand(Operation *op, OpResult opResult,
                       const AnalysisState &state) const {
    // MaskOps do not have tensor OpOperands. The yielded values are the result
    // of the wrapped op.
    auto maskOp = cast<vector::MaskOp>(op);
    size_t resultNum = std::distance(op->getOpResults().begin(),
                                     llvm::find(op->getOpResults(), opResult));
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    return {&yieldOp->getOpOperand(resultNum)};
  }

  LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
                                 const AnalysisState &state) const {
    auto bufferizableOp = cast<BufferizableOpInterface>(op);
    if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
      return failure();

    // TODO: Remove this function when vector.mask bodies can bufferize
    // out-of-place. This is currently not supported because yielding allocs
    // from a block leads to a memory leak and because vector.mask supports
    // only a single op in its body.
    auto maskOp = cast<vector::MaskOp>(op);
    if (!maskOp.getMaskRegion()
             .front()
             .getOps<bufferization::AllocTensorOp>()
             .empty())
      return op->emitOpError("body must bufferize in-place");

    return success();
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto maskOp = cast<vector::MaskOp>(op);

    // Do not bufferize if the masked op is not bufferizable.
    Operation *maskedOp = maskOp.getMaskableOp();
    if (!options.dynCastBufferizableOp(maskedOp))
      return success();

    // Update the terminator: Drop all operands that are not results of the
    // masked op.
    auto yieldOp =
        cast<vector::YieldOp>(maskOp.getMaskRegion().front().getTerminator());
    SmallVector<Value> newReturnValues(maskOp->getNumResults(), Value());
    SmallVector<Value> newYieldedValues;
    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      if (llvm::find(maskedOp->getOpResults(), it.value()) !=
          maskedOp->getOpResults().end()) {
        newYieldedValues.push_back(it.value());
      } else {
        // This used to be a tensor result of the masked op, but is now a
        // memref that is defined outside of the vector.mask op.
        newReturnValues[it.index()] = it.value();
      }
    }
    rewriter.updateRootInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().assign(newYieldedValues);
    });

    // Create a new vector.mask op.
    TypeRange newResultTypes(newYieldedValues);
    auto newOp = rewriter.create<vector::MaskOp>(
        op->getLoc(), newResultTypes, maskOp.getMask(), maskOp.getPassthru(),
        /*maskableOp=*/nullptr,
        /*maskRegionBuilder=*/[](OpBuilder &b, Operation *) {});
    newOp.getRegion().takeBody(maskOp.getMaskRegion());

    // Replace all uses of the old vector.mask op.
    int idx = 0;
    for (int i = 0; i < static_cast<int>(maskOp->getNumResults()); ++i) {
      if (!newReturnValues[i])
        newReturnValues[i] = newOp->getResult(idx++);
    }
    replaceOpWithBufferizedValues(rewriter, maskOp, newReturnValues);
    return success();
  }

  BufferRelation bufferRelation(Operation *op, OpResult opResult,
                                const AnalysisState &state) const {
    return BufferRelation::Equivalent;
  }
};

/// Bufferization of vector.yield. Replaced with a new vector.yield that
/// operates on a memref.
struct YieldOpInterface
    : public BufferizableOpInterface::ExternalModel<YieldOpInterface,
                                                    vector::YieldOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    return false;
  }

  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
                                            const AnalysisState &state) const {
    return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
  }

  bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,
                            const AnalysisState &state) const {
    // Yield operands always bufferize inplace. Otherwise, an alloc + copy
    // may be generated inside the block. We should not return/yield
    // allocations when possible.
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    auto yieldOp = cast<vector::YieldOp>(op);

    // Only supported as a vector.mask terminator.
    auto maskOp = dyn_cast<vector::MaskOp>(yieldOp->getParentOp());
    if (!maskOp)
      return yieldOp->emitError("unsupported vector::YieldOp parent");

    // Do not bufferize if the masked op is not bufferizable.
    Operation *maskedOp = &maskOp.getMaskRegion().front().front();
    if (!options.dynCastBufferizableOp(maskedOp))
      return success();

    // Create a new terminator with the same number of operands. Some of
    // these may get dropped during the bufferization of vector.mask.
    SmallVector<Value> newResults;
    for (Value value : yieldOp.getOperands()) {
      if (value.getType().isa<TensorType>()) {
        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
        if (failed(maybeBuffer))
          return failure();
        newResults.push_back(*maybeBuffer);
      } else {
        newResults.push_back(value);
      }
    }

    replaceOpWithNewBufferizedOp<vector::YieldOp>(rewriter, op, newResults);
    return success();
  }
};

} // namespace
} // namespace vector
} // namespace mlir
@@ -141,5 +294,7 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
    TransferReadOp::attachInterface<TransferReadOpInterface>(*ctx);
    TransferWriteOp::attachInterface<TransferWriteOpInterface>(*ctx);
    GatherOp::attachInterface<GatherOpInterface>(*ctx);
    MaskOp::attachInterface<MaskOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });
}

@@ -0,0 +1,9 @@
// RUN: mlir-opt %s -vector-bufferize -split-input-file -verify-diagnostics \
// RUN:   | FileCheck %s

// CHECK-LABEL: func @mask(
func.func @mask(%t0: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
  // expected-error @+1 {{'vector.mask' op body must bufferize in-place}}
  %0 = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

@@ -43,3 +43,6 @@ func.func @gather(%base: tensor<?x?xf32>, %v: vector<16xi32>, %mask: vector<16xi
  %0 = vector.gather %base[%c0, %c0][%v], %mask, %pass_thru : tensor<?x?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
  return %0 : vector<16xf32>
}
// TODO: Add test case for vector.mask. The masked op can currently not
// bufferize out-of-place, so the only test case is in one-shot-bufferize.mlir.

@@ -0,0 +1,12 @@
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries" -split-input-file | FileCheck %s
// CHECK-LABEL: func @mask(
// CHECK-SAME: %[[t0:.*]]: memref<?xf32, strided<[?], offset: ?>>
func.func @mask(%t0: tensor<?xf32>, %val: vector<16xf32>, %idx: index, %m0: vector<16xi1>) -> tensor<?xf32> {
  // CHECK-NOT: alloc
  // CHECK-NOT: copy
  // CHECK: vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %[[t0]][%{{.*}}] : vector<16xf32>, memref<?xf32, strided<[?], offset: ?>> } : vector<16xi1>
  %0 = vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
  // CHECK: return %[[t0]]
  return %0 : tensor<?xf32>
}