mirror of
https://github.com/intel/llvm.git
synced 2026-02-05 22:17:23 +08:00
[mlir][gpu] Add support for f16 when lowering to nvvm intrinsics
Summary: The NVVM target only provides implementations for tanh etc. on f32 and f64 operands. To also support f16, we now insert operations to extend to f32 and truncate back to f16 around the intrinsic call. Differential Revision: https://reviews.llvm.org/D81473
This commit is contained in:
@@ -20,6 +20,9 @@ namespace mlir {
|
||||
/// depending on the element type that Op operates upon. The function
|
||||
/// declaration is added in case it was not added before.
|
||||
///
|
||||
/// If the input values are of f16 type, the value is first casted to f32, the
|
||||
/// function called and then the result casted back.
|
||||
///
|
||||
/// Example with NVVM:
|
||||
/// %exp_f32 = std.exp %arg_f32 : f32
|
||||
///
|
||||
@@ -44,21 +47,48 @@ public:
|
||||
std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
|
||||
"expected single result op");
|
||||
|
||||
LLVMType resultType = typeConverter.convertType(op->getResult(0).getType())
|
||||
.template cast<LLVM::LLVMType>();
|
||||
LLVMType funcType = getFunctionType(resultType, operands);
|
||||
StringRef funcName = getFunctionName(resultType);
|
||||
static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
|
||||
SourceOp>::value,
|
||||
"expected op with same operand and result types");
|
||||
|
||||
SmallVector<Value, 1> castedOperands;
|
||||
for (Value operand : operands)
|
||||
castedOperands.push_back(maybeCast(operand, rewriter));
|
||||
|
||||
LLVMType resultType =
|
||||
castedOperands.front().getType().cast<LLVM::LLVMType>();
|
||||
LLVMType funcType = getFunctionType(resultType, castedOperands);
|
||||
StringRef funcName = getFunctionName(funcType.getFunctionResultType());
|
||||
if (funcName.empty())
|
||||
return failure();
|
||||
|
||||
LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
|
||||
auto callOp = rewriter.create<LLVM::CallOp>(
|
||||
op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
|
||||
rewriter.replaceOp(op, {callOp.getResult(0)});
|
||||
op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
|
||||
castedOperands);
|
||||
|
||||
if (resultType == operands.front().getType()) {
|
||||
rewriter.replaceOp(op, {callOp.getResult(0)});
|
||||
return success();
|
||||
}
|
||||
|
||||
Value truncated = rewriter.create<LLVM::FPTruncOp>(
|
||||
op->getLoc(), operands.front().getType(), callOp.getResult(0));
|
||||
rewriter.replaceOp(op, {truncated});
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
Value maybeCast(Value operand, PatternRewriter &rewriter) const {
|
||||
LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
|
||||
if (!type.isHalfTy())
|
||||
return operand;
|
||||
|
||||
return rewriter.create<LLVM::FPExtOp>(
|
||||
operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
|
||||
operand);
|
||||
}
|
||||
|
||||
LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
|
||||
ArrayRef<Value> operands) const {
|
||||
using LLVM::LLVMType;
|
||||
|
||||
@@ -219,12 +219,16 @@ gpu.module @test_module {
|
||||
// CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
|
||||
// CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
|
||||
// CHECK-LABEL: func @gpu_tanh
|
||||
func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
|
||||
func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
|
||||
%result16 = std.tanh %arg_f16 : f16
|
||||
// CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
|
||||
// CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
|
||||
// CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
|
||||
%result32 = std.tanh %arg_f32 : f32
|
||||
// CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
|
||||
%result64 = std.tanh %arg_f64 : f64
|
||||
// CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
|
||||
std.return %result32, %result64 : f32, f64
|
||||
std.return %result16, %result32, %result64 : f16, f32, f64
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user