mirror of
https://github.com/intel/llvm.git
synced 2026-01-17 06:40:01 +08:00
[mlir][spirv] Add a pass to unify aliased resource variables
In SPIR-V, resources are represented as global variables that are bound to certain descriptor. SPIR-V requires those global variables to be declared as aliased if multiple ones are bound to the same slot. Such aliased decorations can cause issues for transcompilers like SPIRV-Cross when converting to source shading languages like MSL. So this commit adds a pass to perform analysis of aliased resources and see if we can unify them into one. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D119872
This commit is contained in:
@@ -385,7 +385,7 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
|
||||
OptionalAttr<FlatSymbolRefAttr>:$initializer,
|
||||
OptionalAttr<I32Attr>:$location,
|
||||
OptionalAttr<I32Attr>:$binding,
|
||||
OptionalAttr<I32Attr>:$descriptorSet,
|
||||
OptionalAttr<I32Attr>:$descriptor_set,
|
||||
OptionalAttr<StrAttr>:$builtin
|
||||
);
|
||||
|
||||
|
||||
@@ -55,6 +55,11 @@ std::unique_ptr<OperationPass<spirv::ModuleOp>> createLowerABIAttributesPass();
|
||||
/// spv.CompositeInsert into spv.CompositeConstruct.
|
||||
std::unique_ptr<OperationPass<spirv::ModuleOp>> createRewriteInsertsPass();
|
||||
|
||||
/// Creates an operation pass that unifies access of multiple aliased resources
|
||||
/// into access of one single resource.
|
||||
std::unique_ptr<OperationPass<spirv::ModuleOp>>
|
||||
createUnifyAliasedResourcePass();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -28,6 +28,13 @@ def SPIRVRewriteInsertsPass : Pass<"spirv-rewrite-inserts", "spirv::ModuleOp"> {
|
||||
let constructor = "mlir::spirv::createRewriteInsertsPass()";
|
||||
}
|
||||
|
||||
def SPIRVUnifyAliasedResourcePass
|
||||
: Pass<"spirv-unify-aliased-resource", "spirv::ModuleOp"> {
|
||||
let summary = "Unify access of multiple aliased resources into access of one "
|
||||
"single resource";
|
||||
let constructor = "mlir::spirv::createUnifyAliasedResourcePass()";
|
||||
}
|
||||
|
||||
def SPIRVUpdateVCE : Pass<"spirv-update-vce", "spirv::ModuleOp"> {
|
||||
let summary = "Deduce and attach minimal (version, capabilities, extensions) "
|
||||
"requirements to spv.module ops";
|
||||
|
||||
@@ -3,6 +3,7 @@ set(LLVM_OPTIONAL_SOURCES
|
||||
LowerABIAttributesPass.cpp
|
||||
RewriteInsertsPass.cpp
|
||||
SPIRVConversion.cpp
|
||||
UnifyAliasedResourcePass.cpp
|
||||
UpdateVCEPass.cpp
|
||||
)
|
||||
|
||||
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRSPIRVTransforms
|
||||
DecorateCompositeTypeLayoutPass.cpp
|
||||
LowerABIAttributesPass.cpp
|
||||
RewriteInsertsPass.cpp
|
||||
UnifyAliasedResourcePass.cpp
|
||||
UpdateVCEPass.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
|
||||
452
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
Normal file
452
mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
Normal file
@@ -0,0 +1,452 @@
|
||||
//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements a pass that unifies access of multiple aliased resources
|
||||
// into access of one single resource.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
|
||||
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Pass/AnalysisManager.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <algorithm>
|
||||
|
||||
#define DEBUG_TYPE "spirv-unify-aliased-resource"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
|
||||
using AliasedResourceMap =
|
||||
DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
|
||||
|
||||
/// Collects all aliased resources in the given SPIR-V `moduleOp`.
|
||||
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
|
||||
AliasedResourceMap aliasedResoruces;
|
||||
moduleOp->walk([&aliasedResoruces](spirv::GlobalVariableOp varOp) {
|
||||
if (varOp->getAttrOfType<UnitAttr>("aliased")) {
|
||||
Optional<uint32_t> set = varOp.descriptor_set();
|
||||
Optional<uint32_t> binding = varOp.binding();
|
||||
if (set && binding)
|
||||
aliasedResoruces[{*set, *binding}].push_back(varOp);
|
||||
}
|
||||
});
|
||||
return aliasedResoruces;
|
||||
}
|
||||
|
||||
/// Returns the element type if the given `type` is a runtime array resource:
|
||||
/// `!spv.ptr<!spv.struct<!spv.rtarray<...>>>`. Returns null type otherwise.
|
||||
static Type getRuntimeArrayElementType(Type type) {
|
||||
auto ptrType = type.dyn_cast<spirv::PointerType>();
|
||||
if (!ptrType)
|
||||
return {};
|
||||
|
||||
auto structType = ptrType.getPointeeType().dyn_cast<spirv::StructType>();
|
||||
if (!structType || structType.getNumElements() != 1)
|
||||
return {};
|
||||
|
||||
auto rtArrayType =
|
||||
structType.getElementType(0).dyn_cast<spirv::RuntimeArrayType>();
|
||||
if (!rtArrayType)
|
||||
return {};
|
||||
|
||||
return rtArrayType.getElementType();
|
||||
}
|
||||
|
||||
/// Returns true if all `types`, which can either be scalar or vector types,
|
||||
/// have the same bitwidth base scalar type.
|
||||
static bool hasSameBitwidthScalarType(ArrayRef<spirv::SPIRVType> types) {
|
||||
SmallVector<int64_t> scalarTypes;
|
||||
scalarTypes.reserve(types.size());
|
||||
for (spirv::SPIRVType type : types) {
|
||||
assert(type.isScalarOrVector());
|
||||
if (auto vectorType = type.dyn_cast<VectorType>())
|
||||
scalarTypes.push_back(
|
||||
vectorType.getElementType().getIntOrFloatBitWidth());
|
||||
else
|
||||
scalarTypes.push_back(type.getIntOrFloatBitWidth());
|
||||
}
|
||||
return llvm::is_splat(scalarTypes);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Analysis
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// A class for analyzing aliased resources.
|
||||
///
|
||||
/// Resources are expected to be spv.GlobalVarible that has a descriptor set and
|
||||
/// binding number. Such resources are of the type `!spv.ptr<!spv.struct<...>>`
|
||||
/// per Vulkan requirements.
|
||||
///
|
||||
/// Right now, we only support the case that there is a single runtime array
|
||||
/// inside the struct.
|
||||
class ResourceAliasAnalysis {
|
||||
public:
|
||||
explicit ResourceAliasAnalysis(Operation *);
|
||||
|
||||
/// Returns true if the given `op` can be rewritten to use a canonical
|
||||
/// resource.
|
||||
bool shouldUnify(Operation *op) const;
|
||||
|
||||
/// Returns all descriptors and their corresponding aliased resources.
|
||||
const AliasedResourceMap &getResourceMap() const { return resourceMap; }
|
||||
|
||||
/// Returns the canonical resource for the given descriptor/variable.
|
||||
spirv::GlobalVariableOp
|
||||
getCanonicalResource(const Descriptor &descriptor) const;
|
||||
spirv::GlobalVariableOp
|
||||
getCanonicalResource(spirv::GlobalVariableOp varOp) const;
|
||||
|
||||
/// Returns the element type for the given variable.
|
||||
spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
|
||||
|
||||
private:
|
||||
/// Given the descriptor and aliased resources bound to it, analyze whether we
|
||||
/// can unify them and record if so.
|
||||
void recordIfUnifiable(const Descriptor &descriptor,
|
||||
ArrayRef<spirv::GlobalVariableOp> resources);
|
||||
|
||||
/// Mapping from a descriptor to all aliased resources bound to it.
|
||||
AliasedResourceMap resourceMap;
|
||||
|
||||
/// Mapping from a descriptor to the chosen canonical resource.
|
||||
DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;
|
||||
|
||||
/// Mapping from an aliased resource to its descriptor.
|
||||
DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;
|
||||
|
||||
/// Mapping from an aliased resource to its element (scalar/vector) type.
|
||||
DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
|
||||
// Collect all aliased resources first and put them into different sets
|
||||
// according to the descriptor.
|
||||
AliasedResourceMap aliasedResoruces =
|
||||
collectAliasedResources(cast<spirv::ModuleOp>(root));
|
||||
|
||||
// For each resource set, analyze whether we can unify; if so, try to identify
|
||||
// a canonical resource, whose element type has the largest bitwidth.
|
||||
for (const auto &descriptorResoruce : aliasedResoruces) {
|
||||
recordIfUnifiable(descriptorResoruce.first, descriptorResoruce.second);
|
||||
}
|
||||
}
|
||||
|
||||
bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
|
||||
if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
|
||||
auto canonicalOp = getCanonicalResource(varOp);
|
||||
return canonicalOp && varOp != canonicalOp;
|
||||
}
|
||||
if (auto addressOp = dyn_cast<spirv::AddressOfOp>(op)) {
|
||||
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
|
||||
auto *varOp = SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable());
|
||||
return shouldUnify(varOp);
|
||||
}
|
||||
|
||||
if (auto acOp = dyn_cast<spirv::AccessChainOp>(op))
|
||||
return shouldUnify(acOp.base_ptr().getDefiningOp());
|
||||
if (auto loadOp = dyn_cast<spirv::LoadOp>(op))
|
||||
return shouldUnify(loadOp.ptr().getDefiningOp());
|
||||
if (auto storeOp = dyn_cast<spirv::StoreOp>(op))
|
||||
return shouldUnify(storeOp.ptr().getDefiningOp());
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
|
||||
const Descriptor &descriptor) const {
|
||||
auto varIt = canonicalResourceMap.find(descriptor);
|
||||
if (varIt == canonicalResourceMap.end())
|
||||
return {};
|
||||
return varIt->second;
|
||||
}
|
||||
|
||||
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
|
||||
spirv::GlobalVariableOp varOp) const {
|
||||
auto descriptorIt = descriptorMap.find(varOp);
|
||||
if (descriptorIt == descriptorMap.end())
|
||||
return {};
|
||||
return getCanonicalResource(descriptorIt->second);
|
||||
}
|
||||
|
||||
spirv::SPIRVType
|
||||
ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
|
||||
auto it = elementTypeMap.find(varOp);
|
||||
if (it == elementTypeMap.end())
|
||||
return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
void ResourceAliasAnalysis::recordIfUnifiable(
|
||||
const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
|
||||
// Collect the element types and byte counts for all resources in the
|
||||
// current set.
|
||||
SmallVector<spirv::SPIRVType> elementTypes;
|
||||
SmallVector<int64_t> numBytes;
|
||||
|
||||
for (spirv::GlobalVariableOp resource : resources) {
|
||||
Type elementType = getRuntimeArrayElementType(resource.type());
|
||||
if (!elementType)
|
||||
return; // Unexpected resource variable type.
|
||||
|
||||
auto type = elementType.cast<spirv::SPIRVType>();
|
||||
if (!type.isScalarOrVector())
|
||||
return; // Unexpected resource element type.
|
||||
|
||||
if (auto vectorType = type.dyn_cast<VectorType>())
|
||||
if (vectorType.getNumElements() % 2 != 0)
|
||||
return; // Odd-sized vector has special layout requirements.
|
||||
|
||||
Optional<int64_t> count = type.getSizeInBytes();
|
||||
if (!count)
|
||||
return;
|
||||
|
||||
elementTypes.push_back(type);
|
||||
numBytes.push_back(*count);
|
||||
}
|
||||
|
||||
// Make sure base scalar types have the same bitwdith, so that we don't need
|
||||
// to handle extracting components for now.
|
||||
if (!hasSameBitwidthScalarType(elementTypes))
|
||||
return;
|
||||
|
||||
// Make sure that the canonical resource's bitwidth is divisible by others.
|
||||
// With out this, we cannot properly adjust the index later.
|
||||
auto *maxCount = std::max_element(numBytes.begin(), numBytes.end());
|
||||
if (llvm::any_of(numBytes, [maxCount](int64_t count) {
|
||||
return *maxCount % count != 0;
|
||||
}))
|
||||
return;
|
||||
|
||||
spirv::GlobalVariableOp canonicalResource =
|
||||
resources[std::distance(numBytes.begin(), maxCount)];
|
||||
|
||||
// Update internal data structures for later use.
|
||||
resourceMap[descriptor].assign(resources.begin(), resources.end());
|
||||
canonicalResourceMap[descriptor] = canonicalResource;
|
||||
for (const auto &resource : llvm::enumerate(resources)) {
|
||||
descriptorMap[resource.value()] = descriptor;
|
||||
elementTypeMap[resource.value()] = elementTypes[resource.index()];
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename OpTy>
|
||||
class ConvertAliasResoruce : public OpConversionPattern<OpTy> {
|
||||
public:
|
||||
ConvertAliasResoruce(const ResourceAliasAnalysis &analysis,
|
||||
MLIRContext *context, PatternBenefit benefit = 1)
|
||||
: OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}
|
||||
|
||||
protected:
|
||||
const ResourceAliasAnalysis &analysis;
|
||||
};
|
||||
|
||||
struct ConvertVariable : public ConvertAliasResoruce<spirv::GlobalVariableOp> {
|
||||
using ConvertAliasResoruce::ConvertAliasResoruce;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Just remove the aliased resource. Users will be rewritten to use the
|
||||
// canonical one.
|
||||
rewriter.eraseOp(varOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertAddressOf : public ConvertAliasResoruce<spirv::AddressOfOp> {
|
||||
using ConvertAliasResoruce::ConvertAliasResoruce;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// Rewrite the AddressOf op to get the address of the canoncical resource.
|
||||
auto moduleOp = addressOp->getParentOfType<spirv::ModuleOp>();
|
||||
auto srcVarOp = cast<spirv::GlobalVariableOp>(
|
||||
SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
|
||||
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
|
||||
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, dstVarOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertAccessChain : public ConvertAliasResoruce<spirv::AccessChainOp> {
|
||||
using ConvertAliasResoruce::ConvertAliasResoruce;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto addressOp = acOp.base_ptr().getDefiningOp<spirv::AddressOfOp>();
|
||||
if (!addressOp)
|
||||
return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");
|
||||
|
||||
auto moduleOp = acOp->getParentOfType<spirv::ModuleOp>();
|
||||
auto srcVarOp = cast<spirv::GlobalVariableOp>(
|
||||
SymbolTable::lookupSymbolIn(moduleOp, addressOp.variable()));
|
||||
auto dstVarOp = analysis.getCanonicalResource(srcVarOp);
|
||||
|
||||
spirv::SPIRVType srcElemType = analysis.getElementType(srcVarOp);
|
||||
spirv::SPIRVType dstElemType = analysis.getElementType(dstVarOp);
|
||||
|
||||
if ((srcElemType == dstElemType) ||
|
||||
(srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat())) {
|
||||
// We have the same bitwidth for source and destination element types.
|
||||
// Thie indices keep the same.
|
||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||
acOp, adaptor.base_ptr(), adaptor.indices());
|
||||
return success();
|
||||
}
|
||||
|
||||
Location loc = acOp.getLoc();
|
||||
auto i32Type = rewriter.getI32Type();
|
||||
|
||||
if (srcElemType.isIntOrFloat() && dstElemType.isa<VectorType>()) {
|
||||
// The source indices are for a buffer with scalar element types. Rewrite
|
||||
// them into a buffer with vector element types. We need to scale the last
|
||||
// index for the vector as a whole, then add one level of index for inside
|
||||
// the vector.
|
||||
int ratio = *dstElemType.getSizeInBytes() / *srcElemType.getSizeInBytes();
|
||||
auto ratioValue = rewriter.create<spirv::ConstantOp>(
|
||||
loc, i32Type, rewriter.getI32IntegerAttr(ratio));
|
||||
|
||||
auto indices = llvm::to_vector<4>(acOp.indices());
|
||||
Value oldIndex = indices.back();
|
||||
indices.back() =
|
||||
rewriter.create<spirv::SDivOp>(loc, i32Type, oldIndex, ratioValue);
|
||||
indices.push_back(
|
||||
rewriter.create<spirv::SModOp>(loc, i32Type, oldIndex, ratioValue));
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||
acOp, adaptor.base_ptr(), indices);
|
||||
return success();
|
||||
}
|
||||
|
||||
return rewriter.notifyMatchFailure(acOp, "unsupported src/dst types");
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertLoad : public ConvertAliasResoruce<spirv::LoadOp> {
|
||||
using ConvertAliasResoruce::ConvertAliasResoruce;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcElemType =
|
||||
loadOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
auto dstElemType =
|
||||
adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
|
||||
return rewriter.notifyMatchFailure(loadOp, "not scalar type");
|
||||
|
||||
Location loc = loadOp.getLoc();
|
||||
auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.ptr());
|
||||
if (srcElemType == dstElemType) {
|
||||
rewriter.replaceOp(loadOp, newLoadOp->getResults());
|
||||
} else {
|
||||
auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
|
||||
newLoadOp.value());
|
||||
rewriter.replaceOp(loadOp, castOp->getResults());
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ConvertStore : public ConvertAliasResoruce<spirv::StoreOp> {
|
||||
using ConvertAliasResoruce::ConvertAliasResoruce;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcElemType =
|
||||
storeOp.ptr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
auto dstElemType =
|
||||
adaptor.ptr().getType().cast<spirv::PointerType>().getPointeeType();
|
||||
if (!srcElemType.isIntOrFloat() || !dstElemType.isIntOrFloat())
|
||||
return rewriter.notifyMatchFailure(storeOp, "not scalar type");
|
||||
|
||||
Location loc = storeOp.getLoc();
|
||||
Value value = adaptor.value();
|
||||
if (srcElemType != dstElemType)
|
||||
value = rewriter.create<spirv::BitcastOp>(loc, dstElemType, value);
|
||||
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, adaptor.ptr(), value,
|
||||
storeOp->getAttrs());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class UnifyAliasedResourcePass final
|
||||
: public SPIRVUnifyAliasedResourcePassBase<UnifyAliasedResourcePass> {
|
||||
public:
|
||||
void runOnOperation() override;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void UnifyAliasedResourcePass::runOnOperation() {
|
||||
spirv::ModuleOp moduleOp = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
|
||||
// Analyze aliased resources first.
|
||||
ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();
|
||||
|
||||
ConversionTarget target(*context);
|
||||
target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
|
||||
spirv::AccessChainOp, spirv::LoadOp,
|
||||
spirv::StoreOp>(
|
||||
[&analysis](Operation *op) { return !analysis.shouldUnify(op); });
|
||||
target.addLegalDialect<spirv::SPIRVDialect>();
|
||||
|
||||
// Run patterns to rewrite usages of non-canonical resources.
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
|
||||
ConvertLoad, ConvertStore>(analysis, context);
|
||||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// Drop aliased attribute if we only have one single bound resource for a
|
||||
// descriptor. We need to re-collect the map here given in the above the
|
||||
// conversion is best effort; certain sets may not be converted.
|
||||
AliasedResourceMap resourceMap =
|
||||
collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
|
||||
for (const auto &dr : resourceMap) {
|
||||
const auto &resources = dr.second;
|
||||
if (resources.size() == 1)
|
||||
resources.front()->removeAttr("aliased");
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
|
||||
spirv::createUnifyAliasedResourcePass() {
|
||||
return std::make_unique<UnifyAliasedResourcePass>();
|
||||
}
|
||||
215
mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
Normal file
215
mlir/test/Dialect/SPIRV/Transforms/unify-aliased-resource.mlir
Normal file
@@ -0,0 +1,215 @@
|
||||
// RUN: mlir-opt -split-input-file -spirv-unify-aliased-resource %s -o - | FileCheck %s
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
|
||||
spv.func @load_store_scalar(%index: i32) -> f32 "None" {
|
||||
%c0 = spv.Constant 0 : i32
|
||||
%addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
%ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
%value = spv.Load "StorageBuffer" %ac : f32
|
||||
spv.Store "StorageBuffer" %ac, %value : f32
|
||||
spv.ReturnValue %value : f32
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: spv.module
|
||||
|
||||
// CHECK-NOT: @var01s
|
||||
// CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
// CHECK-NOT: @var01s
|
||||
|
||||
// CHECK: spv.func @load_store_scalar(%[[INDEX:.+]]: i32)
|
||||
// CHECK-DAG: %[[C0:.+]] = spv.Constant 0 : i32
|
||||
// CHECK-DAG: %[[C4:.+]] = spv.Constant 4 : i32
|
||||
// CHECK-DAG: %[[ADDR:.+]] = spv.mlir.addressof @var01v
|
||||
// CHECK: %[[DIV:.+]] = spv.SDiv %[[INDEX]], %[[C4]] : i32
|
||||
// CHECK: %[[MOD:.+]] = spv.SMod %[[INDEX]], %[[C4]] : i32
|
||||
// CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%[[C0]], %[[DIV]], %[[MOD]]]
|
||||
// CHECK: spv.Load "StorageBuffer" %[[AC]]
|
||||
// CHECK: spv.Store "StorageBuffer" %[[AC]]
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
|
||||
spv.func @multiple_uses(%i0: i32, %i1: i32) -> f32 "None" {
|
||||
%c0 = spv.Constant 0 : i32
|
||||
%addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
%ac0 = spv.AccessChain %addr[%c0, %i0] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
%val0 = spv.Load "StorageBuffer" %ac0 : f32
|
||||
%ac1 = spv.AccessChain %addr[%c0, %i1] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
%val1 = spv.Load "StorageBuffer" %ac1 : f32
|
||||
%value = spv.FAdd %val0, %val1 : f32
|
||||
spv.ReturnValue %value : f32
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: spv.module
|
||||
|
||||
// CHECK-NOT: @var01s
|
||||
// CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
// CHECK-NOT: @var01s
|
||||
|
||||
// CHECK: spv.func @multiple_uses
|
||||
// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01v
|
||||
// CHECK: spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
|
||||
// CHECK: spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<3xf32>, stride=16> [0])>, StorageBuffer>
|
||||
|
||||
spv.func @vector3(%index: i32) -> f32 "None" {
|
||||
%c0 = spv.Constant 0 : i32
|
||||
%addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
%ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
%value = spv.Load "StorageBuffer" %ac : f32
|
||||
spv.ReturnValue %value : f32
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: spv.module
|
||||
|
||||
// CHECK: spv.GlobalVariable @var01s bind(0, 1) {aliased}
|
||||
// CHECK: spv.GlobalVariable @var01v bind(0, 1) {aliased}
|
||||
// CHECK: spv.func @vector3
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
spv.GlobalVariable @var01v bind(1, 0) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
|
||||
spv.func @not_aliased(%index: i32) -> f32 "None" {
|
||||
%c0 = spv.Constant 0 : i32
|
||||
%addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
%ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
%value = spv.Load "StorageBuffer" %ac : f32
|
||||
spv.Store "StorageBuffer" %ac, %value : f32
|
||||
spv.ReturnValue %value : f32
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: spv.module
|
||||
|
||||
// CHECK: spv.GlobalVariable @var01s bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
// CHECK: spv.GlobalVariable @var01v bind(1, 0) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
// CHECK: spv.func @not_aliased
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
spv.GlobalVariable @var01s_1 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
spv.GlobalVariable @var01v_1 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
|
||||
spv.func @multiple_aliases(%index: i32) -> f32 "None" {
|
||||
%c0 = spv.Constant 0 : i32
|
||||
|
||||
%addr0 = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
%ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
%val0 = spv.Load "StorageBuffer" %ac0 : f32
|
||||
|
||||
%addr1 = spv.mlir.addressof @var01s_1 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
%ac1 = spv.AccessChain %addr1[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
%val1 = spv.Load "StorageBuffer" %ac1 : f32
|
||||
|
||||
%addr2 = spv.mlir.addressof @var01v_1 : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
%ac2 = spv.AccessChain %addr2[%c0, %index, %c0] : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>, i32, i32, i32
|
||||
%val2 = spv.Load "StorageBuffer" %ac2 : f32
|
||||
|
||||
%add0 = spv.FAdd %val0, %val1 : f32
|
||||
%add1 = spv.FAdd %add0, %val2 : f32
|
||||
spv.ReturnValue %add1 : f32
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: spv.module
|
||||
|
||||
// CHECK-NOT: @var01s
|
||||
// CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
// CHECK-NOT: @var01v_1
|
||||
|
||||
// CHECK: spv.func @multiple_aliases
|
||||
// CHECK: %[[ADDR0:.+]] = spv.mlir.addressof @var01v :
|
||||
// CHECK: spv.AccessChain %[[ADDR0]][%{{.+}}, %{{.+}}, %{{.+}}]
|
||||
// CHECK: %[[ADDR1:.+]] = spv.mlir.addressof @var01v :
|
||||
// CHECK: spv.AccessChain %[[ADDR1]][%{{.+}}, %{{.+}}, %{{.+}}]
|
||||
// CHECK: %[[ADDR2:.+]] = spv.mlir.addressof @var01v :
|
||||
// CHECK: spv.AccessChain %[[ADDR2]][%{{.+}}, %{{.+}}, %{{.+}}]
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @var01s_i32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
|
||||
spv.GlobalVariable @var01s_f32 bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
|
||||
spv.func @different_scalar_type(%index: i32, %val1: f32) -> i32 "None" {
|
||||
%c0 = spv.Constant 0 : i32
|
||||
|
||||
%addr0 = spv.mlir.addressof @var01s_i32 : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
|
||||
%ac0 = spv.AccessChain %addr0[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
%val0 = spv.Load "StorageBuffer" %ac0 : i32
|
||||
|
||||
%addr1 = spv.mlir.addressof @var01s_f32 : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>
|
||||
%ac1 = spv.AccessChain %addr1[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<f32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
spv.Store "StorageBuffer" %ac1, %val1 : f32
|
||||
|
||||
spv.ReturnValue %val0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: spv.module
|
||||
|
||||
// CHECK-NOT: @var01s_f32
|
||||
// CHECK: spv.GlobalVariable @var01s_i32 bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
|
||||
// CHECK-NOT: @var01s_f32
|
||||
|
||||
// CHECK: spv.func @different_scalar_type(%[[INDEX:.+]]: i32, %[[VAL1:.+]]: f32)
|
||||
|
||||
// CHECK: %[[IADDR:.+]] = spv.mlir.addressof @var01s_i32
|
||||
// CHECK: %[[IAC:.+]] = spv.AccessChain %[[IADDR]][%{{.+}}, %[[INDEX]]]
|
||||
// CHECK: spv.Load "StorageBuffer" %[[IAC]] : i32
|
||||
|
||||
// CHECK: %[[FADDR:.+]] = spv.mlir.addressof @var01s_i32
|
||||
// CHECK: %[[FAC:.+]] = spv.AccessChain %[[FADDR]][%cst0_i32, %[[INDEX]]]
|
||||
// CHECK: %[[CAST:.+]] = spv.Bitcast %[[VAL1]] : f32 to i32
|
||||
// CHECK: spv.Store "StorageBuffer" %[[FAC]], %[[CAST]] : i32
|
||||
|
||||
// -----
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
spv.GlobalVariable @var01s bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
|
||||
spv.GlobalVariable @var01v bind(0, 1) {aliased} : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
|
||||
spv.func @different_scalar_type(%index: i32, %val0: i32) -> i32 "None" {
|
||||
%c0 = spv.Constant 0 : i32
|
||||
%addr = spv.mlir.addressof @var01s : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>
|
||||
%ac = spv.AccessChain %addr[%c0, %index] : !spv.ptr<!spv.struct<(!spv.rtarray<i32, stride=4> [0])>, StorageBuffer>, i32, i32
|
||||
%val1 = spv.Load "StorageBuffer" %ac : i32
|
||||
spv.Store "StorageBuffer" %ac, %val0 : i32
|
||||
spv.ReturnValue %val1 : i32
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: spv.module
|
||||
|
||||
// CHECK-NOT: @var01s
|
||||
// CHECK: spv.GlobalVariable @var01v bind(0, 1) : !spv.ptr<!spv.struct<(!spv.rtarray<vector<4xf32>, stride=16> [0])>, StorageBuffer>
|
||||
// CHECK-NOT: @var01s
|
||||
|
||||
// CHECK: spv.func @different_scalar_type(%{{.+}}: i32, %[[VAL0:.+]]: i32)
|
||||
// CHECK: %[[ADDR:.+]] = spv.mlir.addressof @var01v
|
||||
// CHECK: %[[AC:.+]] = spv.AccessChain %[[ADDR]][%{{.+}}, %{{.+}}, %{{.+}}]
|
||||
// CHECK: %[[VAL1:.+]] = spv.Load "StorageBuffer" %[[AC]] : f32
|
||||
// CHECK: %[[CAST1:.+]] = spv.Bitcast %[[VAL1]] : f32 to i32
|
||||
// CHECK: %[[CAST2:.+]] = spv.Bitcast %[[VAL0]] : i32 to f32
|
||||
// CHECK: spv.Store "StorageBuffer" %[[AC]], %[[CAST2]] : f32
|
||||
// CHECK: spv.ReturnValue %[[CAST1]] : i32
|
||||
Reference in New Issue
Block a user