From 3c46debe6b0143932d1fc3ac0d566c7f159f364d Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Thu, 20 Feb 2025 19:43:33 -0500 Subject: [PATCH] [MLIR] Fix 0-dimensional case of conversion of vector ops to GPU (#128075) This is a follow-up to #127844. That PR got vectors of arbitrary rank working, but I hadn't thought about the rank-0 case. Signed-off-by: Benoit Jacob --- mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 3 ++- .../test/Conversion/MathToROCDL/math-to-rocdl.mlir | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index c3b3a78abe7f..8b6c62ca2e36 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -624,7 +624,8 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, const LLVMTypeConverter &converter) { TypeRange operandTypes(operands); if (llvm::any_of(operandTypes, llvm::IsaPred)) { - VectorType vectorType = cast(op->getResultTypes()[0]); + VectorType vectorType = + cast(converter.convertType(op->getResultTypes()[0])); rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType, rewriter, converter)); return success(); diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index 9448304f11db..313d7b086731 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -516,6 +516,20 @@ module { // ----- +module @test_module { + // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 + // CHECK-LABEL: func @math_sin_vector_0d + func.func @math_sin_vector_0d(%arg : vector) -> vector { + // CHECK: llvm.extractelement {{.*}} : vector<1xf16> + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + // CHECK: llvm.insertelement {{.*}} : vector<1xf16> + %result = math.sin %arg : vector + func.return %result : vector + } +} + +// ----- + module @test_module { // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK-LABEL: func @math_sin_vector_1d