Mirror of https://github.com/intel/llvm.git
[mlir][tensor][bufferize] Bufferize tensor.splat op
The op bufferizes similarly to tensor.generate: it is lowered to a linalg.map, which may then lower to a loop nest that fills the buffer.

Differential Revision: https://reviews.llvm.org/D150952
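As a sketch of that path (the linalg.map printed form and the choice of loop lowering below are assumptions for illustration, not part of this commit):

    // Before: splat a scalar into every element of a tensor.
    %t = tensor.splat %f : tensor<10x2x4xf32>

    // After bufferization (sketch): a fresh buffer filled by a linalg.map
    // whose body just yields the splat value.
    %alloc = memref.alloc() : memref<10x2x4xf32>
    linalg.map outs(%alloc : memref<10x2x4xf32>)
      () {
        linalg.yield %f : f32
      }

    // linalg.map may then lower (e.g. via -convert-linalg-to-loops) to a
    // loop nest that stores the value into each element:
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %c4 = arith.constant 4 : index
    %c10 = arith.constant 10 : index
    scf.for %i = %c0 to %c10 step %c1 {
      scf.for %j = %c0 to %c2 step %c1 {
        scf.for %k = %c0 to %c4 step %c1 {
          memref.store %f, %alloc[%i, %j, %k] : memref<10x2x4xf32>
        }
      }
    }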
@@ -1087,6 +1087,54 @@ struct ParallelInsertSliceOpInterface
   }
 };
 
+/// Bufferization of tensor.splat. Bufferizes to a new allocation that is filled
+/// with a linalg.map. Similar to tensor.generate.
+struct SplatOpInterface
+    : public BufferizableOpInterface::ExternalModel<SplatOpInterface,
+                                                    tensor::SplatOp> {
+
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto splatOp = cast<tensor::SplatOp>(op);
+
+    // Should the buffer be deallocated?
+    bool dealloc =
+        shouldDeallocateOpResult(cast<OpResult>(splatOp.getResult()), options);
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpace != Attribute())
+      return op->emitError("memory space not implemented yet");
+
+    // Allocate memory.
+    Location loc = op->getLoc();
+    FailureOr<Value> tensorAlloc =
+        allocateTensorForShapedValue(rewriter, loc, splatOp.getResult(),
+                                     /*escape=*/!dealloc, options,
+                                     /*copy=*/false);
+    if (failed(tensorAlloc))
+      return failure();
+
+    // Create linalg::MapOp.
+    auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+    auto linalgOp =
+        rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
+                                       /*init=*/*tensorAlloc);
+    Block &linalgBody = linalgOp.getMapper().emplaceBlock();
+
+    // Fill the body of the map: yield the splat value for every element.
+    rewriter.setInsertionPointToStart(&linalgBody);
+    rewriter.create<linalg::YieldOp>(loc, splatOp.getInput());
+    rewriter.replaceOp(splatOp, linalgOp.getResult()[0]);
+
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -1110,6 +1158,7 @@ void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
         *ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);
     ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
+    SplatOp::attachInterface<SplatOpInterface>(*ctx);
 
     // Load additional dialects of which ops may get created.
     ctx->loadDialect<arith::ArithDialect, linalg::LinalgDialect>();
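For reference, a minimal sketch of how a tool can pick up these external models before running One-Shot Bufferize; setupContext is a hypothetical helper, while registerBufferizableOpInterfaceExternalModels (visible in the hunk above) and appendDialectRegistry are existing MLIR APIs:

    #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
    #include "mlir/IR/DialectRegistry.h"
    #include "mlir/IR/MLIRContext.h"

    // Hypothetical helper: attach the tensor-dialect bufferization models
    // (now including SplatOpInterface) to a context so that One-Shot
    // Bufferize can bufferize tensor.splat.
    static void setupContext(mlir::MLIRContext &ctx) {
      mlir::DialectRegistry registry;
      mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
      ctx.appendDialectRegistry(registry);
    }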
@@ -582,3 +582,20 @@ func.func @tensor.pad(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
 // CHECK: return %[[r]] : tensor<?x?xindex>
   return %0 : tensor<?x?xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor.splat(
+// CHECK-SAME:  %[[F:.*]]: f32)
+// CHECK-DAG:   %[[ALLOC:.*]] = memref.alloc() {{.*}} : memref<10x2x4xf32>
+// CHECK:       %[[ALLOC_T:.*]] = bufferization.to_tensor %[[ALLOC]]
+// CHECK:       %[[MAPPED:.*]] = linalg.map
+// CHECK:         outs(%[[ALLOC_T]] : tensor<10x2x4xf32>)
+// CHECK:         linalg.yield %[[F]]
+// CHECK:       }
+// CHECK:       return %[[MAPPED]] : tensor<10x2x4xf32>
+// CHECK:     }
+func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
+  %t = tensor.splat %f : tensor<10x2x4xf32>
+  return %t : tensor<10x2x4xf32>
+}
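Expanding the CHECK patterns, the bufferized output this test expects should look roughly as follows (SSA names and the attribute matched by {{.*}} are illustrative):

    func.func @tensor.splat(%f: f32) -> tensor<10x2x4xf32> {
      %alloc = memref.alloc() {alignment = 64 : i64} : memref<10x2x4xf32>
      %alloc_t = bufferization.to_tensor %alloc : memref<10x2x4xf32>
      %mapped = linalg.map outs(%alloc_t : tensor<10x2x4xf32>)
        () {
          linalg.yield %f : f32
        }
      return %mapped : tensor<10x2x4xf32>
    }

Note that the linalg.map here still operates on tensors: only the tensor dialect is bufferized in this test, and the map (created on the to_tensor result of the new allocation) is left for a later bufferization or lowering step.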