Files
llvm/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
Lei Zhang 58df5e6d9a [mlir][spirv] Plumbing target environment into type converter
This commit unifies target environment queries into a new wrapper
class spirv::TargetEnv and shares across various places needing
the functionality. We still create multiple instances of TargetEnv
though given the parent components (type converters, passes,
conversion targets) have different lifetimes.

In the meantime, LowerABIAttributesPass is updated to take into
consideration the target environment, which requires updates to
tests to provide that.

Differential Revision: https://reviews.llvm.org/D76242
2020-03-18 20:11:05 -04:00

185 lines
7.2 KiB
C++

//===- UpdateVCEPass.cpp - Deduce version/capability/extension ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to deduce minimal version/extension/capability
// requirements for a spirv::ModuleOp.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/Passes.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
using namespace mlir;
namespace {
/// Operation pass anchored on spirv::ModuleOp that deduces the minimal
/// version/extension/capability requirements of the module.
class UpdateVCEPass final
    : public OperationPass<UpdateVCEPass, spirv::ModuleOp> {
  // Private (the default access for `class` members); defined out of line
  // later in this file.
  void runOnOperation() override;
};
} // namespace
/// Verifies that every extension requirement in `candidates` can be met by
/// `targetEnv`, recording the extension picked for each requirement in
/// `deducedExtensions`.
///
/// `candidates` is a conjunction of disjunctions: each inner list names
/// alternative extensions of which at least one must be allowed, and every
/// inner list must be satisfied, i.e.
///   ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)).
///
/// Returns success when all requirements are met. On the first requirement
/// that `targetEnv` cannot satisfy, emits an error on `op` listing the
/// alternatives and returns the resulting failure; extensions recorded for
/// earlier requirements remain in `deducedExtensions`.
static LogicalResult checkAndUpdateExtensionRequirements(
    Operation *op, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
    llvm::SetVector<spirv::Extension> &deducedExtensions) {
  for (const auto &alternatives : candidates) {
    Optional<spirv::Extension> allowed = targetEnv.allows(alternatives);
    if (allowed) {
      deducedExtensions.insert(*allowed);
      continue;
    }

    // None of the alternatives is available: report all of them by name.
    SmallVector<StringRef, 4> names;
    for (spirv::Extension extension : alternatives)
      names.push_back(spirv::stringifyExtension(extension));
    return op->emitError("'")
           << op->getName() << "' requires at least one extension in ["
           << llvm::join(names, ", ")
           << "] but none allowed in target environment";
  }
  return success();
}
/// Verifies that every capability requirement in `candidates` can be met by
/// `targetEnv`, recording the capability picked for each requirement in
/// `deducedCapabilities`.
///
/// `candidates` is a conjunction of disjunctions: each inner list names
/// alternative capabilities of which at least one must be allowed, and every
/// inner list must be satisfied, i.e.
///   ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)).
///
/// Returns success when all requirements are met. On the first requirement
/// that `targetEnv` cannot satisfy, emits an error on `op` listing the
/// alternatives and returns the resulting failure; capabilities recorded for
/// earlier requirements remain in `deducedCapabilities`.
static LogicalResult checkAndUpdateCapabilityRequirements(
    Operation *op, const spirv::TargetEnv &targetEnv,
    const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
    llvm::SetVector<spirv::Capability> &deducedCapabilities) {
  for (const auto &alternatives : candidates) {
    Optional<spirv::Capability> allowed = targetEnv.allows(alternatives);
    if (allowed) {
      deducedCapabilities.insert(*allowed);
      continue;
    }

    // None of the alternatives is available: report all of them by name.
    SmallVector<StringRef, 4> names;
    for (spirv::Capability capability : alternatives)
      names.push_back(spirv::stringifyCapability(capability));
    return op->emitError("'")
           << op->getName() << "' requires at least one capability in ["
           << llvm::join(names, ", ")
           << "] but none allowed in target environment";
  }
  return success();
}
/// Deduces the minimal (version, capabilities, extensions) triple needed by
/// the ops nested in the module and attaches it to the module as the VCE
/// triple attribute.
///
/// The pass fails (after emitting an error) when the module has no target
/// environment attribute, when an op needs a version newer than the target
/// environment allows, or when an extension/capability requirement of an op
/// or of one of its value types cannot be satisfied by the target environment.
void UpdateVCEPass::runOnOperation() {
  spirv::ModuleOp module = getOperation();

  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
  if (!targetAttr) {
    module.emitError("missing 'spv.target_env' attribute");
    return signalPassFailure();
  }

  spirv::TargetEnv targetEnv(targetAttr);
  spirv::Version allowedVersion = targetAttr.getVersion();

  // Accumulated requirements; the version starts at the lowest one and is
  // only ever raised.
  spirv::Version deducedVersion = spirv::Version::V_1_0;
  llvm::SetVector<spirv::Extension> deducedExtensions;
  llvm::SetVector<spirv::Capability> deducedCapabilities;

  // Visit every op and fold its requirements into the accumulators above,
  // interrupting the walk at the first requirement that cannot be met.
  WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
    // Minimal version demanded by the op itself.
    if (auto versionQuery = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
      deducedVersion = std::max(deducedVersion, versionQuery.getMinVersion());
      if (deducedVersion > allowedVersion) {
        return op->emitError("'") << op->getName() << "' requires min version "
                                  << spirv::stringifyVersion(deducedVersion)
                                  << " but target environment allows up to "
                                  << spirv::stringifyVersion(allowedVersion);
      }
    }

    // Extensions demanded by the op itself.
    if (auto extensionQuery = dyn_cast<spirv::QueryExtensionInterface>(op)) {
      if (failed(checkAndUpdateExtensionRequirements(
              op, targetEnv, extensionQuery.getExtensions(),
              deducedExtensions)))
        return WalkResult::interrupt();
    }

    // Capabilities demanded by the op itself.
    if (auto capabilityQuery = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
      if (failed(checkAndUpdateCapabilityRequirements(
              op, targetEnv, capabilityQuery.getCapabilities(),
              deducedCapabilities)))
        return WalkResult::interrupt();
    }

    // Collect every type whose requirements must also hold: operand types
    // first, then result types.
    SmallVector<Type, 4> valueTypes(op->operand_type_begin(),
                                    op->operand_type_end());
    valueTypes.append(op->result_type_begin(), op->result_type_end());

    // A global variable carries its type as an attribute rather than as an
    // operand/result type, so it has to be added explicitly.
    if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
      valueTypes.push_back(globalVar.type());

    // Requirements implied by the types; the scratch buffers are reused
    // across types and cleared before each query.
    SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
    SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
    for (Type valueType : valueTypes) {
      auto spirvType = valueType.cast<spirv::SPIRVType>();

      typeExtensions.clear();
      spirvType.getExtensions(typeExtensions);
      if (failed(checkAndUpdateExtensionRequirements(
              op, targetEnv, typeExtensions, deducedExtensions)))
        return WalkResult::interrupt();

      typeCapabilities.clear();
      spirvType.getCapabilities(typeCapabilities);
      if (failed(checkAndUpdateCapabilityRequirements(
              op, targetEnv, typeCapabilities, deducedCapabilities)))
        return WalkResult::interrupt();
    }

    return WalkResult::advance();
  });

  if (walkResult.wasInterrupted())
    return signalPassFailure();

  // TODO(antiagainst): verify that the deduced version is consistent with
  // SPIR-V ops' maximal version requirements.

  auto triple = spirv::VerCapExtAttr::get(
      deducedVersion, deducedCapabilities.getArrayRef(),
      deducedExtensions.getArrayRef(), &getContext());
  module.setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
}
/// Creates a new UpdateVCEPass instance, returned through the generic
/// spirv::ModuleOp pass base type.
std::unique_ptr<OpPassBase<spirv::ModuleOp>>
mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
return std::make_unique<UpdateVCEPass>();
}
/// Static registration object for UpdateVCEPass, constructed with the
/// argument string "spirv-update-vce" and the description given below.
static PassRegistration<UpdateVCEPass>
pass("spirv-update-vce",
"Deduce and attach minimal (version, capabilities, extensions) "
"requirements to spv.module ops");