mirror of
https://github.com/intel/llvm.git
synced 2026-01-19 01:15:50 +08:00
[mlir][Tensor] Implement reifyReturnTypeShapesPerResultDim for tensor.insert_slice.
Differential Revision: https://reviews.llvm.org/D105852
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
include "mlir/Dialect/Tensor/IR/TensorBase.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
|
||||
@@ -99,8 +100,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$source, "int64_t":$index)>,
|
||||
OpBuilder<(ins "Value":$source, "Value":$index)>
|
||||
OpBuilder<(ins "Value":$source, "int64_t":$index)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
@@ -432,6 +432,8 @@ def Tensor_InsertOp : Tensor_Op<"insert",
|
||||
def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
|
||||
Tensor_Dialect, "insert_slice",
|
||||
[NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface,
|
||||
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
|
||||
["reifyReturnTypeShapesPerResultDim"]>,
|
||||
TypesMatchWith<"expected result type to match dest type",
|
||||
"dest", "result", "$_self">]> {
|
||||
let summary = "insert_slice operation";
|
||||
|
||||
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRTensor
|
||||
MLIRCastInterfaces
|
||||
MLIRDialectUtils
|
||||
MLIRIR
|
||||
MLIRInferTypeOpInterface
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRSupport
|
||||
MLIRStandard
|
||||
|
||||
@@ -203,12 +203,6 @@ void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
|
||||
build(builder, result, source, indexValue);
|
||||
}
|
||||
|
||||
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
|
||||
Value index) {
|
||||
auto indexTy = builder.getIndexType();
|
||||
build(builder, result, indexTy, source, index);
|
||||
}
|
||||
|
||||
Optional<int64_t> DimOp::getConstantIndex() {
|
||||
if (auto constantOp = index().getDefiningOp<ConstantOp>())
|
||||
return constantOp.getValue().cast<IntegerAttr>().getInt();
|
||||
@@ -1048,6 +1042,17 @@ OpFoldResult InsertSliceOp::fold(ArrayRef<Attribute>) {
|
||||
return OpFoldResult();
|
||||
}
|
||||
|
||||
LogicalResult InsertSliceOp::reifyReturnTypeShapesPerResultDim(
|
||||
OpBuilder &builder,
|
||||
SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
|
||||
reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
|
||||
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
|
||||
reifiedReturnShapes[0][dim] =
|
||||
builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Pattern to rewrite a insert_slice op with constant arguments.
|
||||
class InsertSliceOpConstantArgumentFolder final
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
|
||||
|
||||
func @insert_slice(
|
||||
%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>,
|
||||
%arg2 : index, %arg3 : index, %arg4 : index) -> (index, index, index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
|
||||
%d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
|
||||
%d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
|
||||
%0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3, %arg4] [%d0, %d1, %d2] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
|
||||
%1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
|
||||
%2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
|
||||
%3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
|
||||
return %1, %2, %3 : index, index, index
|
||||
}
|
||||
// CHECK-LABEL: func @insert_slice(
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
|
||||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
|
||||
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
|
||||
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
|
||||
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
|
||||
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
|
||||
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
|
||||
// CHECK: return %[[D0]], %[[D1]], %[[D2]]
|
||||
@@ -3761,6 +3761,7 @@ td_library(
|
||||
deps = [
|
||||
":CastInterfacesTdFiles",
|
||||
":ControlFlowInterfacesTdFiles",
|
||||
":InferTypeOpInterfaceTdFiles",
|
||||
":OpBaseTdFiles",
|
||||
":SideEffectInterfacesTdFiles",
|
||||
":ViewLikeInterfaceTdFiles",
|
||||
@@ -3814,6 +3815,7 @@ cc_library(
|
||||
":ControlFlowInterfaces",
|
||||
":DialectUtils",
|
||||
":IR",
|
||||
":InferTypeOpInterface",
|
||||
":SideEffectInterfaces",
|
||||
":StandardOps",
|
||||
":Support",
|
||||
|
||||
Reference in New Issue
Block a user