mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[mlir][vector] Fix a crash in VectorToGPU (#113454)
This PR fixes a crash in `VectorToGPU` when the operand of `extOp` is a function argument, which cannot be retrieved using `getDefiningOp`. Fixes #107967.
This commit is contained in:
@@ -200,7 +200,9 @@ static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
|
||||
/// Return true if this integer extend op can be folded into a contract op.
|
||||
template <typename ExtOpTy>
|
||||
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
|
||||
if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
|
||||
auto transferReadOp =
|
||||
extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
|
||||
if (!transferReadOp)
|
||||
return false;
|
||||
return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
|
||||
}
|
||||
|
||||
@@ -517,3 +517,22 @@ func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf1
|
||||
vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
|
||||
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
|
||||
|
||||
// Ensure that no crash occurs when the predecessor operation
|
||||
// of `ext` is not `transfer_read`.
|
||||
|
||||
// CHECK-LABEL: func @test_unsupported
|
||||
// CHECK: vector.contract
|
||||
func.func @test_unsupported(%arg0: vector<4x4xi32>, %arg1: vector<4x4xi32>, %arg2: vector<4x4xi64>) -> vector<4x4xi64 > {
|
||||
%0 = arith.extui %arg0 : vector<4x4xi32> to vector<4x4xi64>
|
||||
%1 = arith.extui %arg1 : vector<4x4xi32> to vector<4x4xi64>
|
||||
%2 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
|
||||
%0, %1, %arg2 : vector<4x4xi64>, vector<4x4xi64> into vector<4x4xi64>
|
||||
return %2 : vector<4x4xi64>
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user