[mlir][AMDGPU] Add emulation pass for atomics on AMDGPU targets

Not all AMDGPU targets support all atomic operations. For example,
there are no atomic floating-point adds on the gfx10 series. Add a
pass that emulates these operations with a compare-and-swap loop, by
analogy to the generic atomicrmw rewrite in MemRefToLLVM.
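
For illustration, this is roughly what the rewrite produces for an f32
fmax the target cannot execute natively. This is a schematic sketch,
not exact MLIR assembly; the value names and type annotations are
illustrative, and the precise output is what the new test at the end
of this patch checks.

  // Before: unsupported on the target chipset.
  amdgpu.raw_buffer_atomic_fmax %val -> %buffer[%idx] : f32 -> memref<?xf32>, i32

  // After: a compare-and-swap retry loop.
  %init = amdgpu.raw_buffer_load %buffer[%idx] : memref<?xf32>, i32 -> f32
  cf.br ^loop(%init : f32)
  ^loop(%prev: f32):
    %new = arith.maxf %val, %prev : f32
    // Returns the value that was in memory before the swap attempt.
    %seen = amdgpu.raw_buffer_atomic_cmpswap %new, %prev -> %buffer[%idx] : f32 -> memref<?xf32>, i32
    // Compare bit patterns, not float values, so the check is exact
    // (for example, a NaN compares equal to itself).
    %prev_i = arith.bitcast %prev : f32 to i32
    %seen_i = arith.bitcast %seen : f32 to i32
    %done = arith.cmpi eq, %seen_i, %prev_i : i32
    cf.cond_br %done, ^after, ^loop(%seen : f32)
  ^after:
    // ... code after the original atomic ...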

This pass is named generically because, in the future, we may add a
memref-to-amdgpu conversion that translates constructs like atomicrmw
fmax (which doesn't generally exist in LLVM) into the relevant
intrinsics, and those intrinsics may themselves require emulation.

Since the AMDGPU dialect now has a pass that operates on it, the
dialect's directory structure is reorganized to match other similarly
complex dialects.

If used, this pass should be run before the amdgpu-to-rocdl conversion.
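
For example, a plausible invocation (the -amdgpu-emulate-atomics flag
and its chipset option are added by this patch; -convert-amdgpu-to-rocdl
is the existing conversion pass, and kernel.mlir is only a placeholder
input):

  mlir-opt kernel.mlir \
      -amdgpu-emulate-atomics=chipset=gfx1030 \
      -convert-amdgpu-to-rocdl=chipset=gfx1030

On a chipset that supports a given atomic natively (e.g. f64 fmax on
gfx90a), the emulation pass leaves the op untouched and the conversion
lowers it directly.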

This commit also adds f64 support to atomic_fmax.

Depends on D148722

Reviewed By: nirvedhmeshram

Differential Revision: https://reviews.llvm.org/D148724
Author: Krzysztof Drewniak
Date:   2023-04-17 21:49:02 +00:00
Parent: 98c1104d41
Commit: cc4703745f
20 changed files with 382 additions and 34 deletions

@@ -8,7 +8,7 @@
#ifndef MLIR_CONVERSION_AMDGPUTOROCDL_AMDGPUTOROCDL_H_
#define MLIR_CONVERSION_AMDGPUTOROCDL_AMDGPUTOROCDL_H_
#include "mlir/Conversion/AMDGPUToROCDL/Chipset.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include <memory>
#include <string>

@@ -1,12 +1,2 @@
add_mlir_dialect(AMDGPU amdgpu)
add_mlir_doc(AMDGPU AMDGPU Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
mlir_tablegen(AMDGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(AMDGPUEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRAMDGPUEnumsGen)
set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
mlir_tablegen(AMDGPUAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=amdgpu)
mlir_tablegen(AMDGPUAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=amdgpu)
add_public_tablegen_target(MLIRAMDGPUAttributesIncGen)
add_subdirectory(IR)
add_subdirectory(Transforms)

@@ -221,7 +221,7 @@ def AMDGPU_RawBufferAtomicFaddOp :
def AMDGPU_RawBufferAtomicFmaxOp :
AMDGPU_Op<"raw_buffer_atomic_fmax", [AllElementTypesMatch<["value", "memref"]>,
AttrSizedOperandSegments]>,
Arguments<(ins F32:$value,
Arguments<(ins AnyTypeOf<[F32, F64]>:$value,
Arg<AnyMemRef, "buffer to operate on", [MemRead, MemWrite]>:$memref,
Variadic<I32>:$indices,
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,

@@ -11,22 +11,22 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_AMDGPU_AMDGPUDIALECT_H_
#define MLIR_DIALECT_AMDGPU_AMDGPUDIALECT_H_
#ifndef MLIR_DIALECT_AMDGPU_IR_AMDGPUDIALECT_H_
#define MLIR_DIALECT_AMDGPU_IR_AMDGPUDIALECT_H_
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h.inc"
#include "mlir/Dialect/AMDGPU/AMDGPUEnums.h.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.h.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/AMDGPU.h.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.h.inc"
#endif // MLIR_DIALECT_AMDGPU_AMDGPUDIALECT_H_
#endif // MLIR_DIALECT_AMDGPU_IR_AMDGPUDIALECT_H_

@@ -0,0 +1,12 @@
add_mlir_dialect(AMDGPU amdgpu)
add_mlir_doc(AMDGPU AMDGPU Dialects/ -gen-dialect-doc)
set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
mlir_tablegen(AMDGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(AMDGPUEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRAMDGPUEnumsGen)
set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
mlir_tablegen(AMDGPUAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=amdgpu)
mlir_tablegen(AMDGPUAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=amdgpu)
add_public_tablegen_target(MLIRAMDGPUAttributesIncGen)

@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name AMDGPU)
add_public_tablegen_target(MLIRAMDGPUTransformsIncGen)
add_dependencies(mlir-headers MLIRAMDGPUTransformsIncGen)
add_mlir_doc(Passes AMDGPUPasses ./ -gen-pass-doc)

@@ -0,0 +1,33 @@
//===-- Passes.h - AMDGPU transformation pass declarations --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the transformation passes for the AMDGPU dialect in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_H_
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
class ConversionTarget;
namespace amdgpu {
#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
#define GEN_PASS_REGISTRATION
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
void populateAmdgpuEmulateAtomicsPatterns(ConversionTarget &target,
RewritePatternSet &patterns,
Chipset chipset);
} // namespace amdgpu
} // namespace mlir
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_H_

@@ -0,0 +1,33 @@
//===-- Passes.td - AMDGPU pass declarations ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the passes for the AMDGPU Dialect in MLIR.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
#define MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
include "mlir/Pass/PassBase.td"
def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
let summary = "Emulate atomic operations on chipsets that do not support them";
let description = [{
This pass rewrites any AMDGPU-specific atomic operation that is not supported
on the given `chipset` into a compare-and-swap loop.
}];
let dependentDialects = [
"cf::ControlFlowDialect",
"arith::ArithDialect",
];
let options = [Option<"chipset", "chipset", "std::string",
/*default=*/"\"gfx000\"",
"Chipset that these operations will run on">];
}
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_

@@ -5,8 +5,8 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_AMDGPUTOROCDL_CHIPSET_H_
#define MLIR_CONVERSION_AMDGPUTOROCDL_CHIPSET_H_
#ifndef MLIR_DIALECT_AMDGPU_UTILS_CHIPSET_H_
#define MLIR_DIALECT_AMDGPU_UTILS_CHIPSET_H_
#include "mlir/Support/LogicalResult.h"

@@ -14,7 +14,7 @@
#ifndef MLIR_INITALLDIALECTS_H_
#define MLIR_INITALLDIALECTS_H_
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"

@@ -15,6 +15,7 @@
#define MLIR_INITALLPASSES_H_
#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Async/Passes.h"
@@ -56,6 +57,7 @@ inline void registerAllPasses() {
// Dialect passes
affine::registerAffinePasses();
amdgpu::registerAMDGPUPasses();
registerAsyncPasses();
arith::registerArithPasses();
bufferization::registerBufferizationPasses();

@@ -10,7 +10,7 @@
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"

@@ -1,6 +1,5 @@
add_mlir_conversion_library(MLIRAMDGPUToROCDL
AMDGPUToROCDL.cpp
Chipset.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/AMDGPUToROCDL
@@ -16,6 +15,7 @@ add_mlir_conversion_library(MLIRAMDGPUToROCDL
MLIRLLVMDialect
MLIRROCDLDialect
MLIRAMDGPUDialect
MLIRAMDGPUUtils
MLIRPass
MLIRTransforms
)

@@ -1 +1,3 @@
add_subdirectory(IR)
add_subdirectory(Transforms)
add_subdirectory(Utils)

@@ -10,7 +10,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -30,16 +30,16 @@
using namespace mlir;
using namespace mlir::amdgpu;
#include "mlir/Dialect/AMDGPU/AMDGPUDialect.cpp.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.cpp.inc"
void AMDGPUDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
>();
}
@@ -282,10 +282,10 @@ LogicalResult MFMAOp::verify() {
return success();
}
#include "mlir/Dialect/AMDGPU/AMDGPUEnums.cpp.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc"
#include "mlir/Dialect/AMDGPU/IR/AMDGPU.cpp.inc"

@@ -0,0 +1,19 @@
add_mlir_dialect_library(MLIRAMDGPUTransforms
EmulateAtomics.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
DEPENDS
MLIRAMDGPUTransformsIncGen
LINK_LIBS PUBLIC
MLIRAMDGPUDialect
MLIRAMDGPUUtils
MLIRArithDialect
MLIRControlFlowDialect
MLIRIR
MLIRPass
MLIRTransforms
MLIRTransformUtils
)

@@ -0,0 +1,189 @@
//===- EmulateAtomics.cpp - Emulate unsupported AMDGPU atomics ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir::amdgpu {
#define GEN_PASS_DEF_AMDGPUEMULATEATOMICSPASS
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
} // namespace mlir::amdgpu
using namespace mlir;
using namespace mlir::amdgpu;
namespace {
struct AmdgpuEmulateAtomicsPass
: public amdgpu::impl::AmdgpuEmulateAtomicsPassBase<
AmdgpuEmulateAtomicsPass> {
using AmdgpuEmulateAtomicsPassBase<
AmdgpuEmulateAtomicsPass>::AmdgpuEmulateAtomicsPassBase;
void runOnOperation() override;
};
template <typename AtomicOp, typename ArithOp>
struct RawBufferAtomicByCasPattern : public OpConversionPattern<AtomicOp> {
using OpConversionPattern<AtomicOp>::OpConversionPattern;
using Adaptor = typename AtomicOp::Adaptor;
LogicalResult
matchAndRewrite(AtomicOp atomicOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
namespace {
enum class DataArgAction : unsigned char {
Duplicate,
Drop,
};
} // namespace
// When we rewrite a generic buffer atomic into a load or a CAS, the number of
// operands, and thus the number of entries needed in operand_segment_sizes,
// changes. We patch the existing attribute list here (rather than constructing
// the attributes from scratch) so that unknown attributes on the atomic are
// preserved instead of being discarded.
static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
SmallVectorImpl<NamedAttribute> &newAttrs,
DataArgAction action) {
newAttrs.reserve(attrs.size());
for (NamedAttribute attr : attrs) {
if (attr.getName().getValue() != "operand_segment_sizes") {
newAttrs.push_back(attr);
continue;
}
auto segmentAttr = attr.getValue().cast<DenseI32ArrayAttr>();
MLIRContext *context = segmentAttr.getContext();
DenseI32ArrayAttr newSegments;
switch (action) {
case DataArgAction::Drop:
newSegments = DenseI32ArrayAttr::get(
context, segmentAttr.asArrayRef().drop_front());
break;
case DataArgAction::Duplicate: {
SmallVector<int32_t> newVals;
ArrayRef<int32_t> oldVals = segmentAttr.asArrayRef();
newVals.push_back(oldVals[0]);
newVals.append(oldVals.begin(), oldVals.end());
newSegments = DenseI32ArrayAttr::get(context, newVals);
break;
}
}
newAttrs.push_back(NamedAttribute(attr.getName(), newSegments));
}
}
template <typename AtomicOp, typename ArithOp>
LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
AtomicOp atomicOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = atomicOp.getLoc();
ArrayRef<NamedAttribute> origAttrs = atomicOp->getAttrs();
ValueRange operands = adaptor.getOperands();
Value data = operands.take_front()[0];
ValueRange invariantArgs = operands.drop_front();
Type dataType = data.getType();
SmallVector<NamedAttribute> loadAttrs;
patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
Value initialLoad =
rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
Block *currentBlock = rewriter.getInsertionBlock();
Block *afterAtomic =
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);
rewriter.setInsertionPointToEnd(loopBlock);
Value prevLoad = loopBlock->getArgument(0);
Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
SmallVector<NamedAttribute> cmpswapAttrs;
patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
SmallVector<Value> cmpswapArgs = {operated, prevLoad};
cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
loc, dataType, cmpswapArgs, cmpswapAttrs);
// We care about exact bitwise equality here, so do some bitcasts.
// These will fold away during lowering to the ROCDL dialect, where
// an int->float bitcast is introduced to account for the fact that cmpswap
// only takes integer arguments.
Value prevLoadForCompare = prevLoad;
Value atomicResForCompare = atomicRes;
if (auto floatDataTy = dataType.dyn_cast<FloatType>()) {
Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
prevLoadForCompare =
rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
atomicResForCompare =
rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
}
Value canLeave = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
loopBlock, atomicRes);
rewriter.replaceOp(atomicOp, {});
return success();
}
void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
// No atomic fadd on gfx10, on anything older than gfx9, or on gfx9 below gfx908.
if (chipset.majorVersion == 10 || chipset.majorVersion < 9 ||
(chipset.majorVersion == 9 && chipset.minorVersion < 0x08)) {
target.addIllegalOp<RawBufferAtomicFaddOp>();
}
// gfx9 has no, or only very limited, support for floating-point min and max.
if (chipset.majorVersion == 9) {
if (chipset.minorVersion >= 0x0a) {
// gfx90a supports f64 max (and min, but we don't have a min wrapper right
// now) but all other types need to be emulated.
target.addDynamicallyLegalOp<RawBufferAtomicFmaxOp>(
[](RawBufferAtomicFmaxOp op) -> bool {
return op.getValue().getType().isF64();
});
} else {
target.addIllegalOp<RawBufferAtomicFmaxOp>();
}
}
patterns
.add<RawBufferAtomicByCasPattern<RawBufferAtomicFaddOp, arith::AddFOp>,
RawBufferAtomicByCasPattern<RawBufferAtomicFmaxOp, arith::MaxFOp>>(
patterns.getContext());
}
void AmdgpuEmulateAtomicsPass::runOnOperation() {
Operation *op = getOperation();
FailureOr<Chipset> maybeChipset = Chipset::parse(chipset);
if (failed(maybeChipset)) {
emitError(op->getLoc(), "Invalid chipset name: " + chipset);
return signalPassFailure();
}
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
RewritePatternSet patterns(&ctx);
target.markUnknownOpDynamicallyLegal(
[](Operation *op) -> bool { return true; });
populateAmdgpuEmulateAtomicsPatterns(target, patterns, *maybeChipset);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
return signalPassFailure();
}

@@ -0,0 +1,10 @@
add_mlir_dialect_library(MLIRAMDGPUUtils
Chipset.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Utils
LINK_LIBS PUBLIC
MLIRAMDGPUDialect
MLIRSupport
)

@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/AMDGPUToROCDL/Chipset.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"

@@ -0,0 +1,52 @@
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx90a %s | FileCheck %s --check-prefixes=CHECK,GFX9
// RUN: mlir-opt -split-input-file -amdgpu-emulate-atomics=chipset=gfx1030 %s | FileCheck %s --check-prefixes=CHECK,GFX10
// -----
func.func @atomic_fmax(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
// CHECK: func @atomic_fmax
// CHECK-SAME: ([[val:%.+]]: f32, [[buffer:%.+]]: memref<?xf32>, [[idx:%.+]]: i32)
// CHECK: gpu.printf "Begin\0A"
// GFX10: amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} [[val]] -> [[buffer]][[[idx]]]
// GFX9: [[ld:%.+]] = amdgpu.raw_buffer_load {foo, indexOffset = 4 : i32} [[buffer]][[[idx]]]
// GFX9: cf.br [[loop:\^.+]]([[ld]] : f32)
// GFX9: [[loop]]([[arg:%.+]]: f32):
// GFX9: [[operated:%.+]] = arith.maxf [[val]], [[arg]]
// GFX9: [[atomicRes:%.+]] = amdgpu.raw_buffer_atomic_cmpswap {foo, indexOffset = 4 : i32} [[operated]], [[arg]] -> [[buffer]][[[idx]]]
// GFX9: [[argCast:%.+]] = arith.bitcast [[arg]] : f32 to i32
// GFX9: [[resCast:%.+]] = arith.bitcast [[atomicRes]] : f32 to i32
// GFX9: [[test:%.+]] = arith.cmpi eq, [[resCast]], [[argCast]]
// GFX9: cf.cond_br [[test]], [[post:\^.+]], [[loop]]([[atomicRes]] : f32)
// GFX9: [[post]]:
// CHECK-NEXT: gpu.printf "End\0A"
gpu.printf "Begin\n"
amdgpu.raw_buffer_atomic_fmax {foo, indexOffset = 4 : i32} %val -> %buffer[%idx] : f32 -> memref<?xf32>, i32
gpu.printf "End\n"
func.return
}
// -----
func.func @atomic_fmax_f64(%val: f64, %buffer: memref<?xf64>, %idx: i32) {
// CHECK: func @atomic_fmax_f64
// CHECK-SAME: ([[val:%.+]]: f64, [[buffer:%.+]]: memref<?xf64>, [[idx:%.+]]: i32)
// CHECK: gpu.printf "Begin\0A"
// GFX9: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// GFX10: amdgpu.raw_buffer_atomic_fmax [[val]] -> [[buffer]][[[idx]]]
// CHECK-NEXT: gpu.printf "End\0A"
gpu.printf "Begin\n"
amdgpu.raw_buffer_atomic_fmax %val -> %buffer[%idx] : f64 -> memref<?xf64>, i32
gpu.printf "End\n"
func.return
}
// -----
func.func @atomic_fadd(%val: f32, %buffer: memref<?xf32>, %idx: i32) {
// CHECK: func @atomic_fadd
// GFX9: amdgpu.raw_buffer_atomic_fadd
// GFX10: amdgpu.raw_buffer_load
// GFX10: amdgpu.raw_buffer_atomic_cmpswap
amdgpu.raw_buffer_atomic_fadd %val -> %buffer[%idx] : f32 -> memref<?xf32>, i32
func.return
}