mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 21:53:12 +08:00
This commit unifies target environment queries into a new wrapper class spirv::TargetEnv and shares across various places needing the functionality. We still create multiple instances of TargetEnv though given the parent components (type converters, passes, conversion targets) have different lifetimes. In the meantime, LowerABIAttributesPass is updated to take into consideration the target environment, which requires updates to tests to provide that. Differential Revision: https://reviews.llvm.org/D76242
185 lines
7.2 KiB
C++
185 lines
7.2 KiB
C++
//===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements a pass to deduce minimal version/extension/capability
|
|
// requirements for a spirv::ModuleOp.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SPIRV/Passes.h"
|
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
|
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
|
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/Visitors.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallSet.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// Pass to deduce minimal version/extension/capability requirements for a
|
|
/// spirv::ModuleOp.
|
|
class UpdateVCEPass final
|
|
: public OperationPass<UpdateVCEPass, spirv::ModuleOp> {
|
|
private:
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
/// Checks that `candidates` extension requirements are possible to be satisfied
|
|
/// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
|
|
/// errors attaching to the given `op` on failures.
|
|
///
|
|
/// `candidates` is a vector of vector for extension requirements following
|
|
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
|
|
/// convention.
|
|
static LogicalResult checkAndUpdateExtensionRequirements(
|
|
Operation *op, const spirv::TargetEnv &targetEnv,
|
|
const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
|
|
llvm::SetVector<spirv::Extension> &deducedExtensions) {
|
|
for (const auto &ors : candidates) {
|
|
if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
|
|
deducedExtensions.insert(*chosen);
|
|
} else {
|
|
SmallVector<StringRef, 4> extStrings;
|
|
for (spirv::Extension ext : ors)
|
|
extStrings.push_back(spirv::stringifyExtension(ext));
|
|
|
|
return op->emitError("'")
|
|
<< op->getName() << "' requires at least one extension in ["
|
|
<< llvm::join(extStrings, ", ")
|
|
<< "] but none allowed in target environment";
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Checks that `candidates`capability requirements are possible to be satisfied
|
|
/// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
|
|
/// errors attaching to the given `op` on failures.
|
|
///
|
|
/// `candidates` is a vector of vector for capability requirements following
|
|
/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
|
|
/// convention.
|
|
static LogicalResult checkAndUpdateCapabilityRequirements(
|
|
Operation *op, const spirv::TargetEnv &targetEnv,
|
|
const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
|
|
llvm::SetVector<spirv::Capability> &deducedCapabilities) {
|
|
for (const auto &ors : candidates) {
|
|
if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
|
|
deducedCapabilities.insert(*chosen);
|
|
} else {
|
|
SmallVector<StringRef, 4> capStrings;
|
|
for (spirv::Capability cap : ors)
|
|
capStrings.push_back(spirv::stringifyCapability(cap));
|
|
|
|
return op->emitError("'")
|
|
<< op->getName() << "' requires at least one capability in ["
|
|
<< llvm::join(capStrings, ", ")
|
|
<< "] but none allowed in target environment";
|
|
}
|
|
}
|
|
return success();
|
|
}
|
|
|
|
void UpdateVCEPass::runOnOperation() {
|
|
spirv::ModuleOp module = getOperation();
|
|
|
|
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
|
|
if (!targetAttr) {
|
|
module.emitError("missing 'spv.target_env' attribute");
|
|
return signalPassFailure();
|
|
}
|
|
|
|
spirv::TargetEnv targetEnv(targetAttr);
|
|
spirv::Version allowedVersion = targetAttr.getVersion();
|
|
|
|
spirv::Version deducedVersion = spirv::Version::V_1_0;
|
|
llvm::SetVector<spirv::Extension> deducedExtensions;
|
|
llvm::SetVector<spirv::Capability> deducedCapabilities;
|
|
|
|
// Walk each SPIR-V op to deduce the minimal version/extension/capability
|
|
// requirements.
|
|
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
|
|
// Op min version requirements
|
|
if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
|
|
deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
|
|
if (deducedVersion > allowedVersion) {
|
|
return op->emitError("'") << op->getName() << "' requires min version "
|
|
<< spirv::stringifyVersion(deducedVersion)
|
|
<< " but target environment allows up to "
|
|
<< spirv::stringifyVersion(allowedVersion);
|
|
}
|
|
}
|
|
|
|
// Op extension requirements
|
|
if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
|
|
if (failed(checkAndUpdateExtensionRequirements(
|
|
op, targetEnv, extensions.getExtensions(), deducedExtensions)))
|
|
return WalkResult::interrupt();
|
|
|
|
// Op capability requirements
|
|
if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
|
|
if (failed(checkAndUpdateCapabilityRequirements(
|
|
op, targetEnv, capabilities.getCapabilities(),
|
|
deducedCapabilities)))
|
|
return WalkResult::interrupt();
|
|
|
|
SmallVector<Type, 4> valueTypes;
|
|
valueTypes.append(op->operand_type_begin(), op->operand_type_end());
|
|
valueTypes.append(op->result_type_begin(), op->result_type_end());
|
|
|
|
// Special treatment for global variables, whose type requirements are
|
|
// conveyed by type attributes.
|
|
if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
|
|
valueTypes.push_back(globalVar.type());
|
|
|
|
// Requirements from values' types
|
|
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
|
|
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
|
|
for (Type valueType : valueTypes) {
|
|
typeExtensions.clear();
|
|
valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
|
|
if (failed(checkAndUpdateExtensionRequirements(
|
|
op, targetEnv, typeExtensions, deducedExtensions)))
|
|
return WalkResult::interrupt();
|
|
|
|
typeCapabilities.clear();
|
|
valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
|
|
if (failed(checkAndUpdateCapabilityRequirements(
|
|
op, targetEnv, typeCapabilities, deducedCapabilities)))
|
|
return WalkResult::interrupt();
|
|
}
|
|
|
|
return WalkResult::advance();
|
|
});
|
|
|
|
if (walkResult.wasInterrupted())
|
|
return signalPassFailure();
|
|
|
|
// TODO(antiagainst): verify that the deduced version is consistent with
|
|
// SPIR-V ops' maximal version requirements.
|
|
|
|
auto triple = spirv::VerCapExtAttr::get(
|
|
deducedVersion, deducedCapabilities.getArrayRef(),
|
|
deducedExtensions.getArrayRef(), &getContext());
|
|
module.setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
|
|
}
|
|
|
|
std::unique_ptr<OpPassBase<spirv::ModuleOp>>
|
|
mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
|
|
return std::make_unique<UpdateVCEPass>();
|
|
}
|
|
|
|
static PassRegistration<UpdateVCEPass>
|
|
pass("spirv-update-vce",
|
|
"Deduce and attach minimal (version, capabilities, extensions) "
|
|
"requirements to spv.module ops");
|