From cc145f40530667d65220536a3e03eabe9fdd46cf Mon Sep 17 00:00:00 2001 From: Cullen Rhodes Date: Thu, 20 Jun 2024 08:07:43 +0100 Subject: [PATCH] [mlir][vector] Disable Gather1DToConditionalLoads for scalable vectors (#96049) Pattern scalarizes vector.gather operations and is incorrect for scalable vectors. --- .../Dialect/Vector/Transforms/LowerVectorGather.cpp | 3 +++ mlir/test/Dialect/Vector/vector-gather-lowering.mlir | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 90128126d0fa..dd027d107d16 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -189,6 +189,9 @@ struct Gather1DToConditionalLoads : OpRewritePattern { if (resultTy.getRank() != 1) return rewriter.notifyMatchFailure(op, "unsupported rank"); + if (resultTy.isScalable()) + return rewriter.notifyMatchFailure(op, "not a fixed-width vector"); + Location loc = op.getLoc(); Type elemTy = resultTy.getElementType(); // Vector type with a single element. Used to generate `vector.loads`. diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir index d047ac629d87..c2eb88afa4db 100644 --- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir @@ -206,3 +206,13 @@ func.func @strided_gather(%base : memref<100x3xf32>, // CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>) // CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32> // CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32> + +// CHECK-LABEL: @scalable_gather_1d +// CHECK-NOT: extract +// CHECK: vector.gather +// CHECK-NOT: extract +func.func @scalable_gather_1d(%base: tensor, %v: vector<[2]xindex>, %mask: vector<[2]xi1>, %pass_thru: vector<[2]xf32>) -> vector<[2]xf32> { + %c0 = arith.constant 0 : index + %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor, vector<[2]xindex>, vector<[2]xi1>, vector<[2]xf32> into vector<[2]xf32> + return %0 : vector<[2]xf32> +}