mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[mlir][spirv] Plumbing target environment into type converter
This commit unifies target environment queries into a new wrapper class spirv::TargetEnv and shares across various places needing the functionality. We still create multiple instances of TargetEnv though given the parent components (type converters, passes, conversion targets) have different lifetimes. In the meantime, LowerABIAttributesPass is updated to take into consideration the target environment, which requires updates to tests to provide that. Differential Revision: https://reviews.llvm.org/D76242
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
#ifndef MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
|
||||
#define MLIR_DIALECT_SPIRV_SPIRVLOWERING_H
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
@@ -27,7 +28,7 @@ namespace mlir {
|
||||
/// pointers to structs.
|
||||
class SPIRVTypeConverter : public TypeConverter {
|
||||
public:
|
||||
SPIRVTypeConverter();
|
||||
explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr);
|
||||
|
||||
/// Gets the SPIR-V correspondence for the standard index type.
|
||||
static Type getIndexType(MLIRContext *context);
|
||||
@@ -40,6 +41,9 @@ public:
|
||||
/// llvm::None if the memory space does not map to any SPIR-V storage class.
|
||||
static Optional<spirv::StorageClass>
|
||||
getStorageClassForMemorySpace(unsigned space);
|
||||
|
||||
private:
|
||||
spirv::TargetEnv targetEnv;
|
||||
};
|
||||
|
||||
/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V.
|
||||
@@ -70,11 +74,10 @@ class FuncOp;
|
||||
class SPIRVConversionTarget : public ConversionTarget {
|
||||
public:
|
||||
/// Creates a SPIR-V conversion target for the given target environment.
|
||||
static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetEnv,
|
||||
MLIRContext *context);
|
||||
static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetAttr);
|
||||
|
||||
private:
|
||||
SPIRVConversionTarget(TargetEnvAttr targetEnv, MLIRContext *context);
|
||||
explicit SPIRVConversionTarget(TargetEnvAttr targetAttr);
|
||||
|
||||
// Be explicit that instance of this class cannot be copied or moved: there
|
||||
// are lambdas capturing fields of the instance.
|
||||
@@ -87,9 +90,7 @@ private:
|
||||
/// environment.
|
||||
bool isLegalOp(Operation *op);
|
||||
|
||||
Version givenVersion; /// SPIR-V version to target
|
||||
llvm::SmallSet<Extension, 4> givenExtensions; /// Allowed extensions
|
||||
llvm::SmallSet<Capability, 8> givenCapabilities; /// Allowed capabilities
|
||||
TargetEnv targetEnv;
|
||||
};
|
||||
|
||||
/// Returns the value for the given `builtin` variable. This function gets or
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
#include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
namespace mlir {
|
||||
class Operation;
|
||||
@@ -22,6 +23,38 @@ class Operation;
|
||||
namespace spirv {
|
||||
enum class StorageClass : uint32_t;
|
||||
|
||||
/// A wrapper class around a spirv::TargetEnvAttr to provide query methods for
|
||||
/// allowed version/capabilities/extensions.
|
||||
class TargetEnv {
|
||||
public:
|
||||
explicit TargetEnv(TargetEnvAttr targetAttr);
|
||||
|
||||
Version getVersion();
|
||||
|
||||
/// Returns true if the given capability is allowed.
|
||||
bool allows(Capability) const;
|
||||
/// Returns the first allowed one if any of the given capabilities is allowed.
|
||||
/// Returns llvm::None otherwise.
|
||||
Optional<Capability> allows(ArrayRef<Capability>) const;
|
||||
|
||||
/// Returns true if the given extension is allowed.
|
||||
bool allows(Extension) const;
|
||||
/// Returns the first allowed one if any of the given extensions is allowed.
|
||||
/// Returns llvm::None otherwise.
|
||||
Optional<Extension> allows(ArrayRef<Extension>) const;
|
||||
|
||||
/// Returns the MLIRContext.
|
||||
MLIRContext *getContext();
|
||||
|
||||
/// Allows implicity converting to the underlying spirv::TargetEnvAttr.
|
||||
operator TargetEnvAttr() const { return targetAttr; }
|
||||
|
||||
private:
|
||||
TargetEnvAttr targetAttr;
|
||||
llvm::SmallSet<Extension, 4> givenExtensions; /// Allowed extensions
|
||||
llvm::SmallSet<Capability, 8> givenCapabilities; /// Allowed capabilities
|
||||
};
|
||||
|
||||
/// Returns the attribute name for specifying argument ABI information.
|
||||
StringRef getInterfaceVarABIAttrName();
|
||||
|
||||
|
||||
@@ -52,14 +52,15 @@ void GPUToSPIRVPass::runOnModule() {
|
||||
kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
|
||||
});
|
||||
|
||||
SPIRVTypeConverter typeConverter;
|
||||
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
|
||||
std::unique_ptr<ConversionTarget> target =
|
||||
spirv::SPIRVConversionTarget::get(targetAttr);
|
||||
|
||||
SPIRVTypeConverter typeConverter(targetAttr);
|
||||
OwningRewritePatternList patterns;
|
||||
populateGPUToSPIRVPatterns(context, typeConverter, patterns);
|
||||
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
|
||||
|
||||
std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
|
||||
spirv::lookupTargetEnvOrDefault(module), context);
|
||||
|
||||
if (failed(applyFullConversion(kernelModules, *target, patterns,
|
||||
&typeConverter))) {
|
||||
return signalPassFailure();
|
||||
|
||||
@@ -25,15 +25,15 @@ void LinalgToSPIRVPass::runOnModule() {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp module = getModule();
|
||||
|
||||
SPIRVTypeConverter typeConverter;
|
||||
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
|
||||
std::unique_ptr<ConversionTarget> target =
|
||||
spirv::SPIRVConversionTarget::get(targetAttr);
|
||||
|
||||
SPIRVTypeConverter typeConverter(targetAttr);
|
||||
OwningRewritePatternList patterns;
|
||||
populateLinalgToSPIRVPatterns(context, typeConverter, patterns);
|
||||
populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
|
||||
|
||||
auto targetEnv = spirv::lookupTargetEnvOrDefault(module);
|
||||
std::unique_ptr<ConversionTarget> target =
|
||||
spirv::SPIRVConversionTarget::get(targetEnv, context);
|
||||
|
||||
// Allow builtin ops.
|
||||
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target->addDynamicallyLegalOp<FuncOp>(
|
||||
|
||||
@@ -31,14 +31,15 @@ void ConvertStandardToSPIRVPass::runOnModule() {
|
||||
MLIRContext *context = &getContext();
|
||||
ModuleOp module = getModule();
|
||||
|
||||
SPIRVTypeConverter typeConverter;
|
||||
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
|
||||
std::unique_ptr<ConversionTarget> target =
|
||||
spirv::SPIRVConversionTarget::get(targetAttr);
|
||||
|
||||
SPIRVTypeConverter typeConverter(targetAttr);
|
||||
OwningRewritePatternList patterns;
|
||||
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
|
||||
populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
|
||||
|
||||
std::unique_ptr<ConversionTarget> target = spirv::SPIRVConversionTarget::get(
|
||||
spirv::lookupTargetEnvOrDefault(module), context);
|
||||
|
||||
if (failed(applyPartialConversion(module, *target, patterns))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
@@ -159,7 +159,8 @@ static Optional<int64_t> getTypeNumBytes(Type t) {
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
SPIRVTypeConverter::SPIRVTypeConverter() {
|
||||
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
|
||||
: targetEnv(targetAttr) {
|
||||
addConversion([](Type type) -> Optional<Type> {
|
||||
// If the type is already valid in SPIR-V, directly return.
|
||||
return spirv::SPIRVDialect::isValidType(type) ? type : Optional<Type>();
|
||||
@@ -411,11 +412,10 @@ mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
std::unique_ptr<spirv::SPIRVConversionTarget>
|
||||
spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv,
|
||||
MLIRContext *context) {
|
||||
spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
|
||||
std::unique_ptr<SPIRVConversionTarget> target(
|
||||
// std::make_unique does not work here because the constructor is private.
|
||||
new SPIRVConversionTarget(targetEnv, context));
|
||||
new SPIRVConversionTarget(targetAttr));
|
||||
SPIRVConversionTarget *targetPtr = target.get();
|
||||
target->addDynamicallyLegalDialect<SPIRVDialect>(
|
||||
Optional<ConversionTarget::DynamicLegalityCallbackFn>(
|
||||
@@ -426,80 +426,57 @@ spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetEnv,
|
||||
}
|
||||
|
||||
spirv::SPIRVConversionTarget::SPIRVConversionTarget(
|
||||
spirv::TargetEnvAttr targetEnv, MLIRContext *context)
|
||||
: ConversionTarget(*context), givenVersion(targetEnv.getVersion()) {
|
||||
for (spirv::Extension ext : targetEnv.getExtensions())
|
||||
givenExtensions.insert(ext);
|
||||
|
||||
// Add extensions implied by the current version.
|
||||
for (spirv::Extension ext : spirv::getImpliedExtensions(givenVersion))
|
||||
givenExtensions.insert(ext);
|
||||
|
||||
for (spirv::Capability cap : targetEnv.getCapabilities()) {
|
||||
givenCapabilities.insert(cap);
|
||||
|
||||
// Add capabilities implied by the current capability.
|
||||
for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
|
||||
givenCapabilities.insert(c);
|
||||
}
|
||||
}
|
||||
spirv::TargetEnvAttr targetAttr)
|
||||
: ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
|
||||
|
||||
/// Checks that `candidates` extension requirements are possible to be satisfied
|
||||
/// with the given `allowedExtensions`.
|
||||
/// with the given `targetEnv`.
|
||||
///
|
||||
/// `candidates` is a vector of vector for extension requirements following
|
||||
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
|
||||
/// convention.
|
||||
static LogicalResult checkExtensionRequirements(
|
||||
Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
|
||||
Operation *op, const spirv::TargetEnv &targetEnv,
|
||||
const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
|
||||
for (const auto &ors : candidates) {
|
||||
auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
|
||||
return allowedExtensions.count(ext);
|
||||
});
|
||||
if (targetEnv.allows(ors))
|
||||
continue;
|
||||
|
||||
if (chosen == ors.end()) {
|
||||
SmallVector<StringRef, 4> extStrings;
|
||||
for (spirv::Extension ext : ors)
|
||||
extStrings.push_back(spirv::stringifyExtension(ext));
|
||||
SmallVector<StringRef, 4> extStrings;
|
||||
for (spirv::Extension ext : ors)
|
||||
extStrings.push_back(spirv::stringifyExtension(ext));
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << op->getName()
|
||||
<< "illegal: requires at least one extension in ["
|
||||
<< llvm::join(extStrings, ", ")
|
||||
<< "] but none allowed in target environment\n");
|
||||
return failure();
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << op->getName()
|
||||
<< " illegal: requires at least one extension in ["
|
||||
<< llvm::join(extStrings, ", ")
|
||||
<< "] but none allowed in target environment\n");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Checks that `candidates`capability requirements are possible to be satisfied
|
||||
/// with the given `allowedCapabilities`.
|
||||
/// with the given `isAllowedFn`.
|
||||
///
|
||||
/// `candidates` is a vector of vector for capability requirements following
|
||||
/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
|
||||
/// convention.
|
||||
static LogicalResult checkCapabilityRequirements(
|
||||
Operation *op,
|
||||
const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
|
||||
Operation *op, const spirv::TargetEnv &targetEnv,
|
||||
const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
|
||||
for (const auto &ors : candidates) {
|
||||
auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
|
||||
return allowedCapabilities.count(cap);
|
||||
});
|
||||
if (targetEnv.allows(ors))
|
||||
continue;
|
||||
|
||||
if (chosen == ors.end()) {
|
||||
SmallVector<StringRef, 4> capStrings;
|
||||
for (spirv::Capability cap : ors)
|
||||
capStrings.push_back(spirv::stringifyCapability(cap));
|
||||
SmallVector<StringRef, 4> capStrings;
|
||||
for (spirv::Capability cap : ors)
|
||||
capStrings.push_back(spirv::stringifyCapability(cap));
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< op->getName()
|
||||
<< "illegal: requires at least one capability in ["
|
||||
<< llvm::join(capStrings, ", ")
|
||||
<< "] but none allowed in target environment\n");
|
||||
return failure();
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << op->getName()
|
||||
<< " illegal: requires at least one capability in ["
|
||||
<< llvm::join(capStrings, ", ")
|
||||
<< "] but none allowed in target environment\n");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@@ -509,7 +486,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
|
||||
// QueryMinVersionInterface/QueryMaxVersionInterface are available to all
|
||||
// SPIR-V versions.
|
||||
if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
|
||||
if (minVersion.getMinVersion() > givenVersion) {
|
||||
if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< op->getName() << " illegal: requiring min version "
|
||||
<< spirv::stringifyVersion(minVersion.getMinVersion())
|
||||
@@ -517,7 +494,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
|
||||
return false;
|
||||
}
|
||||
if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
|
||||
if (maxVersion.getMaxVersion() < givenVersion) {
|
||||
if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< op->getName() << " illegal: requiring max version "
|
||||
<< spirv::stringifyVersion(maxVersion.getMaxVersion())
|
||||
@@ -529,7 +506,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
|
||||
// implementing QueryExtensionInterface do not require extensions to be
|
||||
// available.
|
||||
if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
|
||||
if (failed(checkExtensionRequirements(op, this->givenExtensions,
|
||||
if (failed(checkExtensionRequirements(op, this->targetEnv,
|
||||
extensions.getExtensions())))
|
||||
return false;
|
||||
|
||||
@@ -537,7 +514,7 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
|
||||
// implementing QueryCapabilityInterface do not require capabilities to be
|
||||
// available.
|
||||
if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
|
||||
if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
|
||||
if (failed(checkCapabilityRequirements(op, this->targetEnv,
|
||||
capabilities.getCapabilities())))
|
||||
return false;
|
||||
|
||||
@@ -557,14 +534,13 @@ bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
|
||||
for (Type valueType : valueTypes) {
|
||||
typeExtensions.clear();
|
||||
valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
|
||||
if (failed(checkExtensionRequirements(op, this->givenExtensions,
|
||||
typeExtensions)))
|
||||
if (failed(checkExtensionRequirements(op, this->targetEnv, typeExtensions)))
|
||||
return false;
|
||||
|
||||
typeCapabilities.clear();
|
||||
valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
|
||||
if (failed(checkCapabilityRequirements(op, this->givenCapabilities,
|
||||
typeCapabilities)))
|
||||
if (failed(
|
||||
checkCapabilityRequirements(op, this->targetEnv, typeCapabilities)))
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,67 @@
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TargetEnv
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr)
|
||||
: targetAttr(targetAttr) {
|
||||
for (spirv::Extension ext : targetAttr.getExtensions())
|
||||
givenExtensions.insert(ext);
|
||||
|
||||
// Add extensions implied by the current version.
|
||||
for (spirv::Extension ext :
|
||||
spirv::getImpliedExtensions(targetAttr.getVersion()))
|
||||
givenExtensions.insert(ext);
|
||||
|
||||
for (spirv::Capability cap : targetAttr.getCapabilities()) {
|
||||
givenCapabilities.insert(cap);
|
||||
|
||||
// Add capabilities implied by the current capability.
|
||||
for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
|
||||
givenCapabilities.insert(c);
|
||||
}
|
||||
}
|
||||
|
||||
spirv::Version spirv::TargetEnv::getVersion() {
|
||||
return targetAttr.getVersion();
|
||||
}
|
||||
|
||||
bool spirv::TargetEnv::allows(spirv::Capability capability) const {
|
||||
return givenCapabilities.count(capability);
|
||||
}
|
||||
|
||||
Optional<spirv::Capability>
|
||||
spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const {
|
||||
auto chosen = llvm::find_if(caps, [this](spirv::Capability cap) {
|
||||
return givenCapabilities.count(cap);
|
||||
});
|
||||
if (chosen != caps.end())
|
||||
return *chosen;
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
bool spirv::TargetEnv::allows(spirv::Extension extension) const {
|
||||
return givenExtensions.count(extension);
|
||||
}
|
||||
|
||||
Optional<spirv::Extension>
|
||||
spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const {
|
||||
auto chosen = llvm::find_if(exts, [this](spirv::Extension ext) {
|
||||
return givenExtensions.count(ext);
|
||||
});
|
||||
if (chosen != exts.end())
|
||||
return *chosen;
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
MLIRContext *spirv::TargetEnv::getContext() { return targetAttr.getContext(); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
StringRef spirv::getInterfaceVarABIAttrName() {
|
||||
return "spv.interface_var_abi";
|
||||
}
|
||||
|
||||
@@ -224,7 +224,9 @@ void LowerABIAttributesPass::runOnOperation() {
|
||||
spirv::ModuleOp module = getOperation();
|
||||
MLIRContext *context = &getContext();
|
||||
|
||||
SPIRVTypeConverter typeConverter;
|
||||
spirv::TargetEnv targetEnv(spirv::lookupTargetEnv(module));
|
||||
|
||||
SPIRVTypeConverter typeConverter(targetEnv);
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);
|
||||
|
||||
|
||||
@@ -34,22 +34,18 @@ private:
|
||||
} // namespace
|
||||
|
||||
/// Checks that `candidates` extension requirements are possible to be satisfied
|
||||
/// with the given `allowedExtensions` and updates `deducedExtensions` if so.
|
||||
/// Emits errors attaching to the given `op` on failures.
|
||||
/// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
|
||||
/// errors attaching to the given `op` on failures.
|
||||
///
|
||||
/// `candidates` is a vector of vector for extension requirements following
|
||||
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
|
||||
/// convention.
|
||||
static LogicalResult checkAndUpdateExtensionRequirements(
|
||||
Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
|
||||
Operation *op, const spirv::TargetEnv &targetEnv,
|
||||
const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
|
||||
llvm::SetVector<spirv::Extension> &deducedExtensions) {
|
||||
for (const auto &ors : candidates) {
|
||||
auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
|
||||
return allowedExtensions.count(ext);
|
||||
});
|
||||
|
||||
if (chosen != ors.end()) {
|
||||
if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
|
||||
deducedExtensions.insert(*chosen);
|
||||
} else {
|
||||
SmallVector<StringRef, 4> extStrings;
|
||||
@@ -66,23 +62,18 @@ static LogicalResult checkAndUpdateExtensionRequirements(
|
||||
}
|
||||
|
||||
/// Checks that `candidates`capability requirements are possible to be satisfied
|
||||
/// with the given `allowedCapabilities` and updates `deducedCapabilities` if
|
||||
/// so. Emits errors attaching to the given `op` on failures.
|
||||
/// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
|
||||
/// errors attaching to the given `op` on failures.
|
||||
///
|
||||
/// `candidates` is a vector of vector for capability requirements following
|
||||
/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
|
||||
/// convention.
|
||||
static LogicalResult checkAndUpdateCapabilityRequirements(
|
||||
Operation *op,
|
||||
const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
|
||||
Operation *op, const spirv::TargetEnv &targetEnv,
|
||||
const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
|
||||
llvm::SetVector<spirv::Capability> &deducedCapabilities) {
|
||||
for (const auto &ors : candidates) {
|
||||
auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
|
||||
return allowedCapabilities.count(cap);
|
||||
});
|
||||
|
||||
if (chosen != ors.end()) {
|
||||
if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
|
||||
deducedCapabilities.insert(*chosen);
|
||||
} else {
|
||||
SmallVector<StringRef, 4> capStrings;
|
||||
@@ -101,32 +92,14 @@ static LogicalResult checkAndUpdateCapabilityRequirements(
|
||||
void UpdateVCEPass::runOnOperation() {
|
||||
spirv::ModuleOp module = getOperation();
|
||||
|
||||
spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(module);
|
||||
if (!targetEnv) {
|
||||
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
|
||||
if (!targetAttr) {
|
||||
module.emitError("missing 'spv.target_env' attribute");
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
spirv::Version allowedVersion = targetEnv.getVersion();
|
||||
|
||||
// Build a set for available extensions in the target environment.
|
||||
llvm::SmallSet<spirv::Extension, 4> allowedExtensions;
|
||||
for (spirv::Extension ext : targetEnv.getExtensions())
|
||||
allowedExtensions.insert(ext);
|
||||
|
||||
// Add extensions implied by the current version.
|
||||
for (spirv::Extension ext : spirv::getImpliedExtensions(allowedVersion))
|
||||
allowedExtensions.insert(ext);
|
||||
|
||||
// Build a set for available capabilities in the target environment.
|
||||
llvm::SmallSet<spirv::Capability, 8> allowedCapabilities;
|
||||
for (spirv::Capability cap : targetEnv.getCapabilities()) {
|
||||
allowedCapabilities.insert(cap);
|
||||
|
||||
// Add capabilities implied by the current capability.
|
||||
for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
|
||||
allowedCapabilities.insert(c);
|
||||
}
|
||||
spirv::TargetEnv targetEnv(targetAttr);
|
||||
spirv::Version allowedVersion = targetAttr.getVersion();
|
||||
|
||||
spirv::Version deducedVersion = spirv::Version::V_1_0;
|
||||
llvm::SetVector<spirv::Extension> deducedExtensions;
|
||||
@@ -148,15 +121,14 @@ void UpdateVCEPass::runOnOperation() {
|
||||
|
||||
// Op extension requirements
|
||||
if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
|
||||
if (failed(checkAndUpdateExtensionRequirements(op, allowedExtensions,
|
||||
extensions.getExtensions(),
|
||||
deducedExtensions)))
|
||||
if (failed(checkAndUpdateExtensionRequirements(
|
||||
op, targetEnv, extensions.getExtensions(), deducedExtensions)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
// Op capability requirements
|
||||
if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
|
||||
if (failed(checkAndUpdateCapabilityRequirements(
|
||||
op, allowedCapabilities, capabilities.getCapabilities(),
|
||||
op, targetEnv, capabilities.getCapabilities(),
|
||||
deducedCapabilities)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
@@ -176,13 +148,13 @@ void UpdateVCEPass::runOnOperation() {
|
||||
typeExtensions.clear();
|
||||
valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
|
||||
if (failed(checkAndUpdateExtensionRequirements(
|
||||
op, allowedExtensions, typeExtensions, deducedExtensions)))
|
||||
op, targetEnv, typeExtensions, deducedExtensions)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
typeCapabilities.clear();
|
||||
valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
|
||||
if (failed(checkAndUpdateCapabilityRequirements(
|
||||
op, allowedCapabilities, typeCapabilities, deducedCapabilities)))
|
||||
op, targetEnv, typeCapabilities, deducedCapabilities)))
|
||||
return WalkResult::interrupt();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
// RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: spv.module
|
||||
spv.module Logical GLSL450 {
|
||||
// CHECK-DAG: spv.globalVariable [[VAR0:@.*]] bind(0, 0) : !spv.ptr<!spv.struct<f32 [0]>, StorageBuffer>
|
||||
@@ -24,4 +31,6 @@ spv.module Logical GLSL450 {
|
||||
}
|
||||
// CHECK: spv.EntryPoint "GLCompute" [[FN]]
|
||||
// CHECK: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
|
||||
}
|
||||
} // end spv.module
|
||||
|
||||
} // end module
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
// RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s
|
||||
|
||||
module attributes {
|
||||
spv.target_env = #spv.target_env<
|
||||
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
|
||||
{max_compute_workgroup_invocations = 128 : i32,
|
||||
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
|
||||
} {
|
||||
|
||||
// CHECK-LABEL: spv.module
|
||||
spv.module Logical GLSL450 {
|
||||
// CHECK-DAG: spv.globalVariable [[WORKGROUPSIZE:@.*]] built_in("WorkgroupSize")
|
||||
@@ -119,4 +126,6 @@ spv.module Logical GLSL450 {
|
||||
}
|
||||
// CHECK: spv.EntryPoint "GLCompute" [[FN]], [[WORKGROUPID]], [[LOCALINVOCATIONID]], [[NUMWORKGROUPS]], [[WORKGROUPSIZE]]
|
||||
// CHECK-NEXT: spv.ExecutionMode [[FN]] "LocalSize", 32, 1, 1
|
||||
}
|
||||
} // end spv.module
|
||||
|
||||
} // end module
|
||||
|
||||
@@ -130,7 +130,12 @@ void ConvertToTargetEnv::runOnFunction() {
|
||||
auto targetEnv = fn.getOperation()
|
||||
->getAttr(spirv::getTargetEnvAttrName())
|
||||
.cast<spirv::TargetEnvAttr>();
|
||||
auto target = spirv::SPIRVConversionTarget::get(targetEnv, context);
|
||||
if (!targetEnv) {
|
||||
fn.emitError("missing 'spv.target_env' attribute");
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
auto target = spirv::SPIRVConversionTarget::get(targetEnv);
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
|
||||
|
||||
Reference in New Issue
Block a user