[mlir][tensor][bufferize] Bufferize tensor.splat op

The op bufferizes similarly to tensor.generate: it is lowered to a linalg.map, which may then lower to a loop nest that fills the buffer.

Differential Revision: https://reviews.llvm.org/D150952
This commit is contained in:
Matthias Springer
2023-05-22 14:13:08 +02:00
parent a9e90f7994
commit 481b254e45
2 changed files with 66 additions and 0 deletions

View File

@@ -1087,6 +1087,54 @@ struct ParallelInsertSliceOpInterface
}
};
/// Bufferization of tensor.splat. Bufferizes to a new allocation that is
/// filled with a linalg.map. Similar to tensor.generate.
struct SplatOpInterface
    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
                                                    tensor::SplatOp> {
  /// tensor.splat always materializes a fresh buffer for its result.
  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
    return true;
  }

  /// Lower tensor.splat to an allocation that is filled with the splat value
  /// by a linalg.map whose body simply yields the scalar operand.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    OpBuilder::InsertionGuard guard(rewriter);
    auto splat = cast<tensor::SplatOp>(op);
    Location loc = splat.getLoc();

    // Determine whether the resulting buffer should be deallocated again.
    bool shouldDealloc =
        shouldDeallocateOpResult(cast<OpResult>(splat.getResult()), options);

    // TODO: Implement memory space for this op.
    if (options.defaultMemorySpace != Attribute())
      return op->emitError("memory space not implemented yet");

    // Allocate a tensor for the result.
    FailureOr<Value> allocTensor = allocateTensorForShapedValue(
        rewriter, loc, splat.getResult(), /*escape=*/!shouldDealloc, options,
        /*copy=*/false);
    if (failed(allocTensor))
      return failure();

    // Build a linalg.map with no inputs whose mapper block yields the splat
    // value into every element of the allocation.
    auto resultType = cast<RankedTensorType>(allocTensor->getType());
    auto mapOp = rewriter.create<linalg::MapOp>(
        loc, resultType, /*inputs=*/ValueRange(), /*init=*/*allocTensor);
    Block &mapperBlock = mapOp.getMapper().emplaceBlock();
    rewriter.setInsertionPointToStart(&mapperBlock);
    rewriter.create<linalg::YieldOp>(loc, splat.getInput());

    rewriter.replaceOp(splat, mapOp.getResult()[0]);
    return success();
  }
};
} // namespace
} // namespace tensor
} // namespace mlir
@@ -1110,6 +1158,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
*ctx);
RankOp::attachInterface<RankOpInterface>(*ctx);
ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
SplatOp::attachInterface<SplatOpInterface>(*ctx);
// Load additional dialects of which ops may get created.
ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();

View File

@@ -582,3 +582,20 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
// CHECK: return %[[r]] : tensor<?x?xindex>
return %0 : tensor<?x?xindex>
}
// -----

// Verify that tensor.splat bufferizes to a new allocation that is turned
// back into a tensor and filled via a linalg.map yielding the splat scalar.
// CHECK-LABEL: func @tensor.splat(
// CHECK-SAME: %[[F:.*]]: f32)
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4xf32>
// CHECK: %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
// CHECK: %[[MAPPED:.*]] = linalg.map
// CHECK: outs(%[[ALLOC_T]] : tensor<10x2x4xf32>)
// CHECK: linalg.yield %[[F]]
// CHECK: }
// CHECK: return %[[MAPPED]] : tensor<10x2x4xf32>
// CHECK: }
func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
%t = tensor.splat %f : tensor<10x2x4xf32>
return %t : tensor<10x2x4xf32>
}