mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[mlir][vector] Handle corner cases in DropUnitDimsFromTransposeOp. (#102518)
da8778e499
breaks the lowering of vector.transpose that all the dimensions are unit
dimensions. The revision fixes the issue and adds a test.
---------
Signed-off-by: hanhanW <hanhan0912@gmail.com>
This commit is contained in:
@@ -1771,6 +1771,13 @@ struct DropUnitDimsFromTransposeOp final
|
||||
newPerm.push_back(idx - droppedDimsBefore[idx]);
|
||||
}
|
||||
|
||||
// Fixup for `newPerm`. The `sourceTypeWithoutUnitDims` could be vector<1xT>
|
||||
// type when the dimensions are unit dimensions. In this case, the newPerm
|
||||
// should be [0].
|
||||
if (newPerm.empty()) {
|
||||
newPerm.push_back(0);
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
// Drop the unit dims via shape_cast.
|
||||
auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
|
||||
|
||||
@@ -737,6 +737,18 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
|
||||
|
||||
// -----
|
||||
|
||||
func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
|
||||
%res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
|
||||
return %res : vector<1x1x1xf32>
|
||||
}
|
||||
// The `vec` is returned because there are other flattening patterns that fold
|
||||
// vector.shape_cast ops away.
|
||||
// CHECK-LABEL: func.func @transpose_with_all_unit_dims
|
||||
// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]]
|
||||
// CHECK-NEXT: return %[[VEC]]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
|
||||
%res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
|
||||
return %res : vector<4x3x2xf32>
|
||||
|
||||
Reference in New Issue
Block a user