[mlir][gpu] Add support for f16 when lowering to nvvm intrinsics

Summary:
The NVVM target only provides implementations of tanh etc. for f32 and
f64 operands. To also support f16, we now insert operations around the
intrinsic call that extend the f16 operands to f32 and truncate the
result back to f16.
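
For example, a tanh on an f16 value now lowers by wrapping the f32 libdevice
call in an extend/truncate pair. A minimal sketch in the style of the test
updated below (function and value names are illustrative):

    gpu.module @example {
      func @tanh_f16(%arg : f16) -> f16 {
        // libdevice has no f16 tanh, so this goes through __nv_tanhf.
        %res = std.tanh %arg : f16
        std.return %res : f16
      }
    }

After conversion, the body becomes, schematically:

    %0 = llvm.fpext %arg : !llvm.half to !llvm.float
    %1 = llvm.call @__nv_tanhf(%0) : (!llvm.float) -> !llvm.float
    %2 = llvm.fptrunc %1 : !llvm.float to !llvm.half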

Differential Revision: https://reviews.llvm.org/D81473
Author: Stephan Herhut
Date:   2020-06-09 17:20:53 +02:00
Commit: 2c8afe1298 (parent b7d369280b)

2 changed files with 42 additions and 8 deletions

mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h

@@ -20,6 +20,9 @@ namespace mlir {
 /// depending on the element type that Op operates upon. The function
 /// declaration is added in case it was not added before.
 ///
+/// If the input values are of f16 type, the value is first cast to f32, the
+/// function called, and then the result cast back.
+///
 /// Example with NVVM:
 ///   %exp_f32 = std.exp %arg_f32 : f32
 ///
@@ -44,21 +47,48 @@ public:
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
+    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                  SourceOp>::value,
+                  "expected op with same operand and result types");
+
+    SmallVector<Value, 1> castedOperands;
+    for (Value operand : operands)
+      castedOperands.push_back(maybeCast(operand, rewriter));
 
-    LLVMType resultType = typeConverter.convertType(op->getResult(0).getType())
-                              .template cast<LLVM::LLVMType>();
-    LLVMType funcType = getFunctionType(resultType, operands);
-    StringRef funcName = getFunctionName(resultType);
+    LLVMType resultType =
+        castedOperands.front().getType().cast<LLVM::LLVMType>();
+    LLVMType funcType = getFunctionType(resultType, castedOperands);
+    StringRef funcName = getFunctionName(funcType.getFunctionResultType());
     if (funcName.empty())
       return failure();
 
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
     auto callOp = rewriter.create<LLVM::CallOp>(
-        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
-    rewriter.replaceOp(op, {callOp.getResult(0)});
+        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
+        castedOperands);
+
+    if (resultType == operands.front().getType()) {
+      rewriter.replaceOp(op, {callOp.getResult(0)});
+      return success();
+    }
+
+    Value truncated = rewriter.create<LLVM::FPTruncOp>(
+        op->getLoc(), operands.front().getType(), callOp.getResult(0));
+    rewriter.replaceOp(op, {truncated});
     return success();
   }
 
 private:
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
+    if (!type.isHalfTy())
+      return operand;
+
+    return rewriter.create<LLVM::FPExtOp>(
+        operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+        operand);
+  }
+
   LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
                                  ArrayRef<Value> operands) const {
     using LLVM::LLVMType;
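
The doc comment's exp example extends the same way once maybeCast is in
place. A hedged sketch for an f16 operand, assuming the usual __nv_expf
mapping for std.exp on f32:

    %exp_f16 = std.exp %arg_f16 : f16
    // after conversion:
    %0 = llvm.fpext %arg_f16 : !llvm.half to !llvm.float
    %1 = llvm.call @__nv_expf(%0) : (!llvm.float) -> !llvm.float
    %2 = llvm.fptrunc %1 : !llvm.float to !llvm.half

Note that the call itself is typed entirely in f32; the trailing FPTruncOp
is only emitted because the casted result type differs from the original
operand type, which is exactly the resultType == operands.front().getType()
check above.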

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

@@ -219,12 +219,16 @@ gpu.module @test_module {
   // CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_tanh
-  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = std.tanh %arg_f16 : f16
+    // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
+    // CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
     %result32 = std.tanh %arg_f32 : f32
     // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.tanh %arg_f64 : f64
     // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return %result32, %result64 : f32, f64
+    std.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
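
To exercise this locally, the test runs through mlir-opt and FileCheck; a
sketch of the RUN line, assuming the -convert-gpu-to-nvvm pass flag used by
this test file:

    // RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s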