From e458434ebe87f890db0d4a03bbc3de30f3d052b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Tue, 12 Nov 2024 19:08:54 +0000 Subject: [PATCH] [mlir][vector] Restrict narrow-type-emulation patterns (#115612) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All patterns in populateVectorNarrowTypeEmulationPatterns currently assume a 1-D vector load/store rather than an n-D vector load/store. This assumption is evident in ConvertVectorTransferRead, for example, here (extracted from `ConvertVectorTransferRead`): ```cpp auto newRead = rewriter.create( loc, VectorType::get(numElements, newElementType), adaptor.getSource(), getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices), newPadding); auto bitCast = rewriter.create( loc, VectorType::get(numElements * scale, oldElementType), newRead); ``` Both invocations of `VectorType::get()` here generate a 1-D vector. Attempts to use these patterns with more generic cases, such as 2-D vectors, fail. For example, trying to cast the following 2-D case to `i32`: ```mlir func.func @vector_maskedload_2d_i8_negative( %idx1: index, %idx2: index, %num_elems: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> { %0 = memref.alloc() : memref<3x4xi8> %mask = vector.create_mask %num_elems, %num_elems : vector<2x4xi1> %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru : memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8> return %1 : vector<2x4xi8> } ``` For example, casting to i32 produces: ```bash error: 'vector.bitcast' op failed to verify that all of {source, result} have same rank %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru : ^ ``` Instead of reworking these patterns (that's going to require much more effort), I’ve marked them as 1-D only and extended "TestEmulateNarrowTypePass" with an option to disable the Memref type converter - that's to be able to add negative tests (otherwise, the type converter throws an error we can't really test for). While not ideal, this workaround should suit a test pass. --- .../Transforms/VectorEmulateNarrowType.cpp | 25 ++++ .../emulate-narrow-type-unsupported.mlir | 111 ++++++++++++++++++ .../Dialect/MemRef/TestEmulateNarrowType.cpp | 11 +- 3 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 7578aadee23a..bb0731d768df 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -249,6 +249,11 @@ struct ConvertVectorStore final : OpConversionPattern { matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // See #115653 + if (op.getValueToStore().getType().getRank() != 1) + return rewriter.notifyMatchFailure(op, + "only 1-D vectors are supported ATM"); + auto loc = op.getLoc(); auto convertedType = cast(adaptor.getBase().getType()); Type oldElementType = op.getValueToStore().getType().getElementType(); @@ -315,6 +320,11 @@ struct ConvertVectorMaskedStore final matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // See #115653 + if (op.getValueToStore().getType().getRank() != 1) + return rewriter.notifyMatchFailure(op, + "only 1-D vectors are supported ATM"); + auto loc = op.getLoc(); auto convertedType = cast(adaptor.getBase().getType()); Type oldElementType = op.getValueToStore().getType().getElementType(); @@ -418,6 +428,11 @@ struct ConvertVectorLoad final : OpConversionPattern { matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // See #115653 + if (op.getVectorType().getRank() != 1) + return rewriter.notifyMatchFailure(op, + "only 1-D vectors are supported ATM"); + auto loc = op.getLoc(); auto convertedType = cast(adaptor.getBase().getType()); Type oldElementType = op.getType().getElementType(); @@ -517,6 +532,11 @@ struct ConvertVectorMaskedLoad final matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // See #115653 + if (op.getVectorType().getRank() != 1) + return rewriter.notifyMatchFailure(op, + "only 1-D vectors are supported ATM"); + auto loc = op.getLoc(); auto convertedType = cast(adaptor.getBase().getType()); Type oldElementType = op.getType().getElementType(); @@ -674,6 +694,11 @@ struct ConvertVectorTransferRead final matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // See #115653 + if (op.getVectorType().getRank() != 1) + return rewriter.notifyMatchFailure(op, + "only 1-D vectors are supported ATM"); + auto loc = op.getLoc(); auto convertedType = cast(adaptor.getSource().getType()); Type oldElementType = op.getType().getElementType(); diff --git a/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir new file mode 100644 index 000000000000..a5a6fc4acfe1 --- /dev/null +++ b/mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir @@ -0,0 +1,111 @@ +// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s + +// These tests mimic tests from vector-narrow-type.mlir, but load/store 2-D +// insted of 1-D vectors. That's currently not supported. + +///---------------------------------------------------------------------------------------- +/// vector.load +///---------------------------------------------------------------------------------------- + +func.func @vector_load_2d_i8_negative(%arg1: index, %arg2: index) -> vector<2x4xi8> { + %0 = memref.alloc() : memref<3x4xi8> + %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<2x4xi8> + return %1 : vector<2x4xi8> +} + +// No support for loading 2D vectors - expect no conversions +// CHECK-LABEL: func @vector_load_2d_i8_negative +// CHECK: memref.alloc() : memref<3x4xi8> +// CHECK-NOT: i32 + +// ----- + +///---------------------------------------------------------------------------------------- +/// vector.transfer_read +///---------------------------------------------------------------------------------------- + +func.func @vector_transfer_read_2d_i4_negative(%arg1: index, %arg2: index) -> vector<2x8xi4> { + %c0 = arith.constant 0 : i4 + %0 = memref.alloc() : memref<3x8xi4> + %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true, true]} : + memref<3x8xi4>, vector<2x8xi4> + return %1 : vector<2x8xi4> +} +// CHECK-LABEL: func @vector_transfer_read_2d_i4_negative +// CHECK: memref.alloc() : memref<3x8xi4> +// CHECK-NOT: i32 + +// ----- + +///---------------------------------------------------------------------------------------- +/// vector.maskedload +///---------------------------------------------------------------------------------------- + +func.func @vector_maskedload_2d_i8_negative(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> { + %0 = memref.alloc() : memref<3x4xi8> + %mask = vector.create_mask %arg3, %arg3 : vector<2x4xi1> + %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru : + memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8> + return %1 : vector<2x4xi8> +} + +// CHECK-LABEL: func @vector_maskedload_2d_i8_negative +// CHECK: memref.alloc() : memref<3x4xi8> +// CHECK-NOT: i32 + +// ----- + +///---------------------------------------------------------------------------------------- +/// vector.extract -> vector.masked_load +///---------------------------------------------------------------------------------------- + +func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x16xi4> { + %0 = memref.alloc() : memref<8x8x16xi4> + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c8 = arith.constant 8 : index + %cst_1 = arith.constant dense<0> : vector<8x8x16xi4> + %cst_2 = arith.constant dense<0> : vector<8x16xi4> + %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1> + %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1> + %50 = vector.maskedload %0[%c0, %c0, %c0], %48, %cst_2 : memref<8x8x16xi4>, vector<8x16xi1>, vector<8x16xi4> into vector<8x16xi4> + %63 = vector.insert %50, %cst_1 [0] : vector<8x16xi4> into vector<8x8x16xi4> + return %63 : vector<8x8x16xi4> +} + +// CHECK-LABEL: func @vector_extract_maskedload_2d_i4_negative +// CHECK: memref.alloc() : memref<8x8x16xi4> +// CHECK-NOT: i32 + +// ----- + +///---------------------------------------------------------------------------------------- +/// vector.store +///---------------------------------------------------------------------------------------- + +func.func @vector_store_2d_i8_negative(%arg0: vector<2x8xi8>, %arg1: index, %arg2: index) { + %0 = memref.alloc() : memref<4x8xi8> + vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xi8>, vector<2x8xi8> + return +} + +// CHECK-LABEL: func @vector_store_2d_i8_negative +// CHECK: memref.alloc() : memref<4x8xi8> +// CHECK-NOT: i32 + +// ----- + +///---------------------------------------------------------------------------------------- +/// vector.maskedstore +///---------------------------------------------------------------------------------------- + +func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<2x8xi8>) { + %0 = memref.alloc() : memref<3x8xi8> + %mask = vector.create_mask %arg2, %arg2 : vector<2x8xi1> + vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<2x8xi1>, vector<2x8xi8> + return +} + +// CHECK-LABEL: func @vector_maskedstore_2d_i8_negative +// CHECK: memref.alloc() : memref<3x8xi8> +// CHECK-NOT: i32 diff --git a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp index eeb26d1876c1..7401e470ed4f 100644 --- a/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp +++ b/mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp @@ -78,7 +78,11 @@ struct TestEmulateNarrowTypePass IntegerType::get(ty.getContext(), arithComputeBitwidth)); }); - memref::populateMemRefNarrowTypeEmulationConversions(typeConverter); + // With the type converter enabled, we are effectively unable to write + // negative tests. This is a workaround specifically for negative tests. + if (!disableMemrefTypeConversion) + memref::populateMemRefNarrowTypeEmulationConversions(typeConverter); + ConversionTarget target(*ctx); target.addDynamicallyLegalOp([&typeConverter](Operation *op) { return typeConverter.isLegal(cast(op).getFunctionType()); @@ -109,6 +113,11 @@ struct TestEmulateNarrowTypePass Option arithComputeBitwidth{ *this, "arith-compute-bitwidth", llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)}; + + Option disableMemrefTypeConversion{ + *this, "skip-memref-type-conversion", + llvm::cl::desc("disable memref type conversion (to test failures)"), + llvm::cl::init(false)}; }; } // namespace