[mlir][linalg] Expose function to create op on buffers during bufferization.

Differential Revision: https://reviews.llvm.org/D109140
Author: Alexander Belyaev
Date: 2021-09-02 11:06:49 +02:00
parent d581d94385
commit f68de11c10
2 changed files with 41 additions and 53 deletions
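
In short, this change merges the two static helpers in Bufferize.cpp, finalizeBufferAllocationForGenericOp and finalizeBufferAllocation, into a single createLinalgOpOnBuffers function and declares it in the Linalg Transforms header so that other bufferization patterns can reuse it. A minimal sketch of such a downstream pattern follows; the pattern name, the include set, and the way the converted inputs and output buffers are obtained are illustrative assumptions, not part of this change.

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
/// Hypothetical downstream pattern that bufferizes any linalg op by calling
/// the newly exposed helper instead of re-implementing the generic/non-generic
/// split. For brevity it reuses the converted output operands as output
/// buffers; the in-tree pattern allocates fresh buffers first.
struct MyLinalgBufferizePattern
    : public OpInterfaceConversionPattern<linalg::LinalgOp> {
  using OpInterfaceConversionPattern<
      linalg::LinalgOp>::OpInterfaceConversionPattern;

  LogicalResult
  matchAndRewrite(linalg::LinalgOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // `operands` are the type-converted (memref) operands: inputs first,
    // then outputs, mirroring the tensor-based op's operand order.
    ValueRange inputs = operands.take_front(op.getNumInputs());
    ValueRange outputBuffers = operands.drop_front(op.getNumInputs());

    // Build the buffer-based clone of `op`. The helper only creates and
    // returns the new op; it does not touch the original one.
    linalg::createLinalgOpOnBuffers(rewriter, op, inputs, outputBuffers);

    // Replacing the tensor results with the output buffers stays with the
    // caller, exactly as the updated in-tree pattern does below.
    rewriter.replaceOp(op, outputBuffers);
    return success();
  }
};
} // namespace

Note that createLinalgOpOnBuffers returns the newly created LinalgOp and leaves the replacement of the tensor-based op to the caller, which is why the updated in-tree pattern now calls rewriter.replaceOp itself (see the second hunk below).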

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

@@ -80,6 +80,12 @@ void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
 void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
                                      RewritePatternSet &patterns);
 
+/// Create linalg op on buffers given the original tensor-based operation and
+/// the buffers for the outputs.
+LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
+                                 LinalgOp linalgOp, ValueRange inputs,
+                                 ValueRange outputs);
+
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);

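As a usage reference for the declarations above, the sketch below wires the linalg bufferization patterns into a partial conversion. The driver function, the legality setup, and the include paths are assumptions reflecting the MLIR tree around this commit; only BufferizeTypeConverter and populateLinalgBufferizePatterns come from the headers involved.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

/// Assumed driver: bufferize the linalg ops in `func` as a partial conversion.
static LogicalResult bufferizeLinalgOps(FuncOp func) {
  MLIRContext *ctx = func.getContext();

  BufferizeTypeConverter typeConverter;
  RewritePatternSet patterns(ctx);
  linalg::populateLinalgBufferizePatterns(typeConverter, patterns);

  ConversionTarget target(*ctx);
  // Ops whose operand and result types are already buffer-legal are left
  // alone; tensor-based linalg ops are rewritten by the populated patterns.
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  return applyPartialConversion(func, target, std::move(patterns));
}
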
mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp

@@ -73,56 +73,44 @@ allocateBuffersForResults(Location loc, LinalgOp linalgOp, ValueRange outputs,
   return success();
 }
 
-/// Specialization for `linalg::GenericOp`.
-/// A pattern to convert Generic Linalg operations which work on tensors to
-/// use buffers. BufferPlacement pass should be later used to move
-/// Alloc operations to the correct positions and insert the missing Dealloc
-/// operations in the correct places.
-static void
-finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
-                                     GenericOp genericOp, ValueRange inputs,
-                                     ValueRange outputs) {
-  // Generate a new linalg operation that works on buffers.
-  auto newGenericOp = rewriter.create<GenericOp>(
-      genericOp.getLoc(),
-      /*resultTensorTypes=*/llvm::None,
-      /*inputs=*/inputs,
-      /*outputs=*/outputs, genericOp.indexing_maps(),
-      genericOp.iterator_types(), genericOp.docAttr(),
-      genericOp.library_callAttr());
+/// Create linalg op on buffers given the original tensor-based operation and
+/// the buffers for the outputs.
+LinalgOp
+mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
+                                      LinalgOp linalgOp, ValueRange inputs,
+                                      ValueRange outputs) {
+  if (auto genericOp = mlir::dyn_cast<GenericOp>(*linalgOp)) {
+    // Generate a new linalg operation that works on buffers.
+    auto newGenericOp = rewriter.create<GenericOp>(
+        genericOp.getLoc(),
+        /*resultTensorTypes=*/llvm::None,
+        /*inputs=*/inputs,
+        /*outputs=*/outputs, genericOp.indexing_maps(),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr());
 
-  // Create a new block in the region of the new Generic Op.
-  Block *oldBlock = genericOp.getBody();
-  Region &newRegion = newGenericOp.region();
-  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
-                                         oldBlock->getArgumentTypes());
+    // Create a new block in the region of the new Generic Op.
+    Block *oldBlock = genericOp.getBody();
+    Region &newRegion = newGenericOp.region();
+    Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
+                                           oldBlock->getArgumentTypes());
 
-  // Clone the body of the old block to the new block.
-  BlockAndValueMapping mapping;
-  mapping.map(oldBlock->getArguments(), newBlock->getArguments());
+    // Clone the body of the old block to the new block.
+    BlockAndValueMapping mapping;
+    mapping.map(oldBlock->getArguments(), newBlock->getArguments());
 
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToEnd(newBlock);
-  for (auto &op : oldBlock->getOperations()) {
-    Operation *clonedOp = rewriter.clone(op, mapping);
-    mapping.map(op.getResults(), clonedOp->getResults());
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToEnd(newBlock);
+    for (auto &op : oldBlock->getOperations()) {
+      Operation *clonedOp = rewriter.clone(op, mapping);
+      mapping.map(op.getResults(), clonedOp->getResults());
+    }
+    return newGenericOp;
   }
 
-  // Replace the results of the old op with the new output buffers.
-  rewriter.replaceOp(genericOp, outputs);
-}
-
-/// Specialization for all other `linalg::LinalgOp`.
-static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
-                                     linalg::LinalgOp linalgOp,
-                                     ValueRange inputs, ValueRange outputs) {
-  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
   SmallVector<Value, 8> newOperands = inputs;
   newOperands.append(outputs.begin(), outputs.end());
-  linalgOp.clone(rewriter, linalgOp.getLoc(),
-                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
-
-  // Replace the results of the old op with the new output buffers.
-  rewriter.replaceOp(linalgOp, outputs);
+  return linalgOp.clone(rewriter, linalgOp.getLoc(),
+                        /*resultTypes=*/ArrayRef<Type>{}, newOperands);
 }
 //===----------------------------------------------------------------------===//
@@ -218,15 +206,9 @@ public:
       return op.emitOpError()
             << "Failed to allocate buffers for tensor results.";
     }
-
-    // Delegate to the linalg generic pattern.
-    if (auto genericOp = dyn_cast<linalg::GenericOp>(*op)) {
-      finalizeBufferAllocationForGenericOp(rewriter, genericOp,
-                                           adaptor.inputs(), newOutputBuffers);
-      return success();
-    }
-
-    finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
+    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
+    // Replace the results of the old op with the new output buffers.
+    rewriter.replaceOp(op, newOutputBuffers);
     return success();
   }
 };