[mlir][Tensor] Implement reifyReturnTypeShapesPerResultDim for tensor.insert_slice.

Differential Revision: https://reviews.llvm.org/D105852
This commit is contained in:
MaheshRavishankar
2021-07-13 14:51:20 -07:00
parent 123e8dfcf8
commit f2b5e438aa
6 changed files with 46 additions and 8 deletions

View File

@@ -15,6 +15,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

View File

@@ -12,6 +12,7 @@
include "mlir/Dialect/Tensor/IR/TensorBase.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -99,8 +100,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
}];
let builders = [
OpBuilder<(ins "Value":$source, "int64_t":$index)>,
OpBuilder<(ins "Value":$source, "Value":$index)>
OpBuilder<(ins "Value":$source, "int64_t":$index)>
];
let extraClassDeclaration = [{
@@ -432,6 +432,8 @@ def Tensor_InsertOp : Tensor_Op<"insert",
def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
Tensor_Dialect, "insert_slice",
[NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>,
TypesMatchWith<"expected result type to match dest type",
"dest", "result", "$_self">]> {
let summary = "insert_slice operation";

View File

@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRTensor
MLIRCastInterfaces
MLIRDialectUtils
MLIRIR
MLIRInferTypeOpInterface
MLIRSideEffectInterfaces
MLIRSupport
MLIRStandard

View File

@@ -203,12 +203,6 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
build(builder, result, source, indexValue);
}
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
Value index) {
auto indexTy = builder.getIndexType();
build(builder, result, indexTy, source, index);
}
Optional<int64_t> DimOp::getConstantIndex() {
if (auto constantOp = index().getDefiningOp<ConstantOp>())
return constantOp.getValue().cast<IntegerAttr>().getInt();
@@ -1048,6 +1042,17 @@ OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
return OpFoldResult();
}
LogicalResult InsertSliceOp::reifyReturnTypeShapesPerResultDim(
OpBuilder &builder,
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
reifiedReturnShapes[0][dim] =
builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
}
return success();
}
namespace {
/// Pattern to rewrite a insert_slice op with constant arguments.
class InsertSliceOpConstantArgumentFolder final

View File

@@ -0,0 +1,27 @@
// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
func @insert_slice(
%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>,
%arg2 : index, %arg3 : index, %arg4 : index) -> (index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
%d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
%d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
%0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3, %arg4] [%d0, %d1, %d2] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
%1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
%2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
%3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
return %1, %2, %3 : index, index, index
}
// CHECK-LABEL: func @insert_slice(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK: return %[[D0]], %[[D1]], %[[D2]]

View File

@@ -3761,6 +3761,7 @@ td_library(
deps = [
":CastInterfacesTdFiles",
":ControlFlowInterfacesTdFiles",
":InferTypeOpInterfaceTdFiles",
":OpBaseTdFiles",
":SideEffectInterfacesTdFiles",
":ViewLikeInterfaceTdFiles",
@@ -3814,6 +3815,7 @@ cc_library(
":ControlFlowInterfaces",
":DialectUtils",
":IR",
":InferTypeOpInterface",
":SideEffectInterfaces",
":StandardOps",
":Support",