[mlir][vector] Restrict narrow-type-emulation patterns (#115612)
All patterns in `populateVectorNarrowTypeEmulationPatterns` currently
assume 1-D rather than n-D vector loads/stores. This assumption is
evident in, for example, `ConvertVectorTransferRead`:
```cpp
auto newRead = rewriter.create<vector::TransferReadOp>(
    loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
    getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
    newPadding);

auto bitCast = rewriter.create<vector::BitCastOp>(
    loc, VectorType::get(numElements * scale, oldElementType), newRead);
```
Both invocations of `VectorType::get()` here create a 1-D vector type.
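To see what these 1-D types correspond to in IR, consider a plain 1-D
`vector.load` of `i8` emulated with a 32-bit container. This is a rough
sketch, not the pass's verbatim output; `%alloc` and `%lin_idx` are
illustrative names and the linearized index computation is omitted:
```mlir
// Supported today: a 1-D load of 4 x i8.
func.func @vector_load_i8(%idx: index) -> vector<4xi8> {
  %0 = memref.alloc() : memref<12xi8>
  %1 = vector.load %0[%idx] : memref<12xi8>, vector<4xi8>
  return %1 : vector<4xi8>
}

// After emulation (sketch): a single i32 element is loaded from the
// container memref, then bit-cast back to the narrow element type.
// %2 = vector.load %alloc[%lin_idx] : memref<3xi32>, vector<1xi32>
// %3 = vector.bitcast %2 : vector<1xi32> to vector<4xi8>
```
Here `vector<1xi32>` corresponds to the first `VectorType::get()` call
(`numElements = 1`) and `vector<4xi8>` to the second
(`numElements * scale = 4`).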
Attempts to use these patterns with more general cases, such as 2-D
vectors, fail. Take the following 2-D example:
```mlir
func.func @vector_maskedload_2d_i8_negative(
    %idx1: index,
    %idx2: index,
    %num_elems: index,
    %passthru: vector<2x4xi8>) -> vector<2x4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %mask = vector.create_mask %num_elems, %num_elems : vector<2x4xi1>
  %1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
      memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
  return %1 : vector<2x4xi8>
}
```
Attempting to emulate it with an `i32` container type produces:
```bash
error: 'vector.bitcast' op failed to verify that all of {source, result} have same rank
%1 = vector.maskedload %0[%idx1, %idx2], %mask, %passthru :
^
```
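The rank mismatch appears to come from the pass-through operand: the
pattern bit-casts it to the 1-D container vector type, but
`vector.bitcast` requires the source and result ranks to match. Roughly,
the invalid op it would try to build looks like this (illustrative types,
assuming a 32-bit container):
```mlir
// Invalid: rank-2 source vs. rank-1 result.
%bc = vector.bitcast %passthru : vector<2x4xi8> to vector<2xi32>
```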
Instead of reworking these patterns (that would require considerably more
effort), I've marked them as 1-D only and extended
`TestEmulateNarrowTypePass` with an option to disable the MemRef type
converter. That option makes it possible to add negative tests (otherwise
the type converter emits an error that we can't really test for). While
not ideal, this workaround should suit a test pass.
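With the new option in place, the unsupported cases can be pinned down by
negative tests. A minimal sketch, condensed from the new test file added
below:
```mlir
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s

func.func @vector_load_2d_i8_negative(%row: index, %col: index) -> vector<2x4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %1 = vector.load %0[%row, %col] : memref<3x4xi8>, vector<2x4xi8>
  return %1 : vector<2x4xi8>
}

// The patterns bail out on the 2-D load, so nothing is converted to i32.
// CHECK-LABEL: func @vector_load_2d_i8_negative
// CHECK-NOT: i32
```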
Committed by GitHub. Parent: ba572abeb4. Commit: e458434ebe.
```diff
@@ -249,6 +249,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
+    if (op.getValueToStore().getType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -315,6 +320,11 @@ struct ConvertVectorMaskedStore final
   matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
+    if (op.getValueToStore().getType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
@@ -418,6 +428,11 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
@@ -517,6 +532,11 @@ struct ConvertVectorMaskedLoad final
   matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
@@ -674,6 +694,11 @@ struct ConvertVectorTransferRead final
   matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
+    // See #115653
+    if (op.getVectorType().getRank() != 1)
+      return rewriter.notifyMatchFailure(op,
+                                         "only 1-D vectors are supported ATM");
+
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
     Type oldElementType = op.getType().getElementType();
```
mlir/test/Dialect/Vector/emulate-narrow-type-unsupported.mlir (new file, 111 lines)
```mlir
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=32 skip-memref-type-conversion" --split-input-file %s | FileCheck %s

// These tests mimic tests from vector-narrow-type.mlir, but load/store 2-D
// instead of 1-D vectors. That's currently not supported.

///----------------------------------------------------------------------------------------
/// vector.load
///----------------------------------------------------------------------------------------

func.func @vector_load_2d_i8_negative(%arg1: index, %arg2: index) -> vector<2x4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %1 = vector.load %0[%arg1, %arg2] : memref<3x4xi8>, vector<2x4xi8>
  return %1 : vector<2x4xi8>
}

// No support for loading 2-D vectors - expect no conversions
// CHECK-LABEL: func @vector_load_2d_i8_negative
// CHECK: memref.alloc() : memref<3x4xi8>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.transfer_read
///----------------------------------------------------------------------------------------

func.func @vector_transfer_read_2d_i4_negative(%arg1: index, %arg2: index) -> vector<2x8xi4> {
  %c0 = arith.constant 0 : i4
  %0 = memref.alloc() : memref<3x8xi4>
  %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true, true]} :
      memref<3x8xi4>, vector<2x8xi4>
  return %1 : vector<2x8xi4>
}
// CHECK-LABEL: func @vector_transfer_read_2d_i4_negative
// CHECK: memref.alloc() : memref<3x8xi4>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.maskedload
///----------------------------------------------------------------------------------------

func.func @vector_maskedload_2d_i8_negative(%arg1: index, %arg2: index, %arg3: index, %passthru: vector<2x4xi8>) -> vector<2x4xi8> {
  %0 = memref.alloc() : memref<3x4xi8>
  %mask = vector.create_mask %arg3, %arg3 : vector<2x4xi1>
  %1 = vector.maskedload %0[%arg1, %arg2], %mask, %passthru :
      memref<3x4xi8>, vector<2x4xi1>, vector<2x4xi8> into vector<2x4xi8>
  return %1 : vector<2x4xi8>
}

// CHECK-LABEL: func @vector_maskedload_2d_i8_negative
// CHECK: memref.alloc() : memref<3x4xi8>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.extract -> vector.masked_load
///----------------------------------------------------------------------------------------

func.func @vector_extract_maskedload_2d_i4_negative(%arg1: index) -> vector<8x8x16xi4> {
  %0 = memref.alloc() : memref<8x8x16xi4>
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %c8 = arith.constant 8 : index
  %cst_1 = arith.constant dense<0> : vector<8x8x16xi4>
  %cst_2 = arith.constant dense<0> : vector<8x16xi4>
  %27 = vector.create_mask %c8, %arg1, %c16 : vector<8x8x16xi1>
  %48 = vector.extract %27[0] : vector<8x16xi1> from vector<8x8x16xi1>
  %50 = vector.maskedload %0[%c0, %c0, %c0], %48, %cst_2 : memref<8x8x16xi4>, vector<8x16xi1>, vector<8x16xi4> into vector<8x16xi4>
  %63 = vector.insert %50, %cst_1 [0] : vector<8x16xi4> into vector<8x8x16xi4>
  return %63 : vector<8x8x16xi4>
}

// CHECK-LABEL: func @vector_extract_maskedload_2d_i4_negative
// CHECK: memref.alloc() : memref<8x8x16xi4>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.store
///----------------------------------------------------------------------------------------

func.func @vector_store_2d_i8_negative(%arg0: vector<2x8xi8>, %arg1: index, %arg2: index) {
  %0 = memref.alloc() : memref<4x8xi8>
  vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi8>, vector<2x8xi8>
  return
}

// CHECK-LABEL: func @vector_store_2d_i8_negative
// CHECK: memref.alloc() : memref<4x8xi8>
// CHECK-NOT: i32

// -----

///----------------------------------------------------------------------------------------
/// vector.maskedstore
///----------------------------------------------------------------------------------------

func.func @vector_maskedstore_2d_i8_negative(%arg0: index, %arg1: index, %arg2: index, %value: vector<2x8xi8>) {
  %0 = memref.alloc() : memref<3x8xi8>
  %mask = vector.create_mask %arg2, %arg2 : vector<2x8xi1>
  vector.maskedstore %0[%arg0, %arg1], %mask, %value : memref<3x8xi8>, vector<2x8xi1>, vector<2x8xi8>
  return
}

// CHECK-LABEL: func @vector_maskedstore_2d_i8_negative
// CHECK: memref.alloc() : memref<3x8xi8>
// CHECK-NOT: i32
```
```diff
@@ -78,7 +78,11 @@ struct TestEmulateNarrowTypePass
           IntegerType::get(ty.getContext(), arithComputeBitwidth));
     });
 
-    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+    // With the type converter enabled, we are effectively unable to write
+    // negative tests. This is a workaround specifically for negative tests.
+    if (!disableMemrefTypeConversion)
+      memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);
+
     ConversionTarget target(*ctx);
     target.addDynamicallyLegalOp<func::FuncOp>([&typeConverter](Operation *op) {
       return typeConverter.isLegal(cast<func::FuncOp>(op).getFunctionType());
@@ -109,6 +113,11 @@ struct TestEmulateNarrowTypePass
   Option<unsigned> arithComputeBitwidth{
       *this, "arith-compute-bitwidth",
       llvm::cl::desc("arith computation bit width"), llvm::cl::init(4)};
+
+  Option<bool> disableMemrefTypeConversion{
+      *this, "skip-memref-type-conversion",
+      llvm::cl::desc("disable memref type conversion (to test failures)"),
+      llvm::cl::init(false)};
 };
 } // namespace
```