[MLIR][Standard] Add documentation for std.dim and fix test cases

Apply post-commit suggestions (see https://reviews.llvm.org/D81551).
Add documentation, simplify, and fix test cases.

Differential Revision: https://reviews.llvm.org/D81722
This commit is contained in:
Frederik Gossen
2020-06-15 10:39:05 +00:00
parent 9baba7cf66
commit 361f664850
5 changed files with 16 additions and 16 deletions

View File

@@ -1376,6 +1376,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
The `dim` operation takes a memref/tensor and a dimension operand of type
`index`.
It returns the size of the requested dimension of the given memref/tensor.
If the dimension index is out of bounds the behavior is undefined.
The specified memref or tensor type is that of the first operand.

View File

@@ -2118,7 +2118,7 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
Optional<int64_t> index = dimOp.getConstantIndex();
if (!index.hasValue()) {
// TODO(frgossen): Implement this lowering.
// TODO: Implement this lowering.
return failure();
}

View File

@@ -17,6 +17,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
@@ -206,14 +207,10 @@ static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
assert(index.hasValue() &&
"expect only `dim` operations with a constant index");
int64_t i = index.getValue();
if (auto viewOp = dyn_cast<ViewOp>(dimOp.memrefOrTensor().getDefiningOp()))
return isMemRefSizeValidSymbol<ViewOp>(viewOp, i, region);
if (auto subViewOp =
dyn_cast<SubViewOp>(dimOp.memrefOrTensor().getDefiningOp()))
return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, i, region);
if (auto allocOp = dyn_cast<AllocOp>(dimOp.memrefOrTensor().getDefiningOp()))
return isMemRefSizeValidSymbol<AllocOp>(allocOp, i, region);
return false;
return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
.Case<ViewOp, SubViewOp, AllocOp>(
[&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
.Default([](Operation *) { return false; });
}
// A value can be used as a symbol (at all its use sites) iff it meets one of

View File

@@ -1273,10 +1273,8 @@ void DimOp::build(OpBuilder &builder, OperationState &result,
}
Optional<int64_t> DimOp::getConstantIndex() {
auto constantOp = index().getDefiningOp<ConstantOp>();
if (constantOp) {
if (auto constantOp = index().getDefiningOp<ConstantOp>())
return constantOp.getValue().cast<IntegerAttr>().getInt();
}
return {};
}

View File

@@ -233,7 +233,8 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<3x
// CHECK-DAG: %[[splat:.*]] = constant dense<7.000000e+00> : vector<15xf32>
// CHECK-DAG: %[[alloc:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<15xf32>>
// CHECK-DAG: %[[dim:.*]] = dim %[[A]], %c0 : memref<?x?xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[dim:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
// CHECK: affine.for %[[I:.*]] = 0 to 3 {
// CHECK: %[[add:.*]] = affine.apply #[[$MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cond1:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
@@ -248,8 +249,9 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<3x
// FULL-UNROLL: %[[pad:.*]] = constant 7.000000e+00 : f32
// FULL-UNROLL: %[[VEC0:.*]] = constant dense<7.000000e+00> : vector<3x15xf32>
// FULL-UNROLL: %[[C0:.*]] = constant 0 : index
// FULL-UNROLL: %[[SPLAT:.*]] = constant dense<7.000000e+00> : vector<15xf32>
// FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], %c0 : memref<?x?xf32>
// FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
// FULL-UNROLL: cmpi "slt", %[[base]], %[[DIM]] : index
// FULL-UNROLL: %[[VEC1:.*]] = scf.if %{{.*}} -> (vector<3x15xf32>) {
// FULL-UNROLL: vector.transfer_read %[[A]][%[[base]], %[[base]]], %[[pad]] : memref<?x?xf32>, vector<15xf32>
@@ -304,10 +306,11 @@ func @transfer_read_progressive(%A : memref<?x?xf32>, %base: index) -> vector<3x
// FULL-UNROLL-SAME: %[[base:[a-zA-Z0-9]+]]: index,
// FULL-UNROLL-SAME: %[[vec:[a-zA-Z0-9]+]]: vector<3x15xf32>
func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vector<3x15xf32>) {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[alloc:.*]] = alloca() {alignment = 128 : i64} : memref<3xvector<15xf32>>
// CHECK: %[[vmemref:.*]] = vector.type_cast %[[alloc]] : memref<3xvector<15xf32>> to memref<vector<3x15xf32>>
// CHECK: store %[[vec]], %[[vmemref]][] : memref<vector<3x15xf32>>
// CHECK: %[[dim:.*]] = dim %[[A]], %c0 : memref<?x?xf32>
// CHECK: %[[dim:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
// CHECK: affine.for %[[I:.*]] = 0 to 3 {
// CHECK: %[[add:.*]] = affine.apply #[[$MAP0]](%[[I]])[%[[base]]]
// CHECK: %[[cmp:.*]] = cmpi "slt", %[[add]], %[[dim]] : index
@@ -316,7 +319,8 @@ func @transfer_write_progressive(%A : memref<?x?xf32>, %base: index, %vec: vecto
// CHECK: vector.transfer_write %[[vec_1d]], %[[A]][%[[add]], %[[base]]] : vector<15xf32>, memref<?x?xf32>
// CHECK: }
// FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], %c0 : memref<?x?xf32>
// FULL-UNROLL: %[[C0:.*]] = constant 0 : index
// FULL-UNROLL: %[[DIM:.*]] = dim %[[A]], %[[C0]] : memref<?x?xf32>
// FULL-UNROLL: %[[CMP0:.*]] = cmpi "slt", %[[base]], %[[DIM]] : index
// FULL-UNROLL: scf.if %[[CMP0]] {
// FULL-UNROLL: %[[V0:.*]] = vector.extract %[[vec]][0] : vector<3x15xf32>