[mlir][tosa] Add specification versioning to target environment (#156425)

This commit adds a new "specification_version" field to the TOSA target
environment attribute. This allows a user to specify which version of
the TOSA specification they would like to target during lowering.

A leading example in the validation pass has also been added. This
addition adds a version to each profile compliance entry to track which
version of the specification the entry was added. This allows a
backwards compatibility check to be implemented between the target
version and the profile compliance entry version.

For now a default version of "1.0" is assumed. "1.1.draft" is added to
denote an in-development version of the specification targeting the next
release.
This commit is contained in:
Luke Hutton
2025-10-16 09:52:34 +01:00
committed by GitHub
parent c48aa54656
commit 62e786ae63
11 changed files with 826 additions and 359 deletions

View File

@@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
/// returned by getDefaultTargetEnv() if not provided.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
/// A thin wrapper around the SpecificationVersion enum to represent
/// and provide utilities around the TOSA specification version.
class TosaSpecificationVersion {
public:
TosaSpecificationVersion(uint32_t major, uint32_t minor)
: majorVersion(major), minorVersion(minor) {}
TosaSpecificationVersion(SpecificationVersion version)
: TosaSpecificationVersion(fromVersionEnum(version)) {}
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const {
return this->majorVersion == baseVersion.majorVersion &&
this->minorVersion >= baseVersion.minorVersion;
}
uint32_t getMajor() const { return majorVersion; }
uint32_t getMinor() const { return minorVersion; }
private:
uint32_t majorVersion = 0;
uint32_t minorVersion = 0;
static TosaSpecificationVersion
fromVersionEnum(SpecificationVersion version) {
switch (version) {
case SpecificationVersion::V_1_0:
return TosaSpecificationVersion(1, 0);
case SpecificationVersion::V_1_1_DRAFT:
return TosaSpecificationVersion(1, 1);
}
llvm_unreachable("Unknown TOSA version");
}
};
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
/// This class represents the capability enabled in the target implementation
/// such as profile, extension, and level. It's a wrapper class around
/// tosa::TargetEnvAttr.
class TargetEnv {
public:
TargetEnv() {}
explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
: level(level) {
: specificationVersion(specificationVersion), level(level) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}
explicit TargetEnv(TargetEnvAttr targetAttr)
: TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
targetAttr.getExtensions()) {}
: TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
targetAttr.getProfiles(), targetAttr.getExtensions()) {}
void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }
// TODO implement the following utilities.
// Version getSpecVersion() const;
SpecificationVersion getSpecVersion() const { return specificationVersion; }
TosaLevel getLevel() const {
if (level == Level::eightK)
@@ -105,6 +140,7 @@ public:
}
private:
SpecificationVersion specificationVersion;
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;

File diff suppressed because it is too large Load Diff

View File

@@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
}
//===----------------------------------------------------------------------===//
// TOSA Spec Section 1.5.
// TOSA Profiles and extensions
//
// Profile:
// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
@@ -293,12 +293,6 @@ def Tosa_ExtensionAttr
def Tosa_ExtensionArrayAttr
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
def Tosa_LevelAttr
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
// The base class for defining op availability dimensions.
class Availability {
// The following are fields for controlling the generated C++ OpInterface.
@@ -404,18 +398,41 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
let instance = "ref";
}
//===----------------------------------------------------------------------===//
// TOSA Levels
//===----------------------------------------------------------------------===//
def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
def Tosa_LevelAttr
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
//===----------------------------------------------------------------------===//
// TOSA Specification versions
//===----------------------------------------------------------------------===//
def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">;
def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">;
def Tosa_SpecificationVersion : Tosa_I32EnumAttr<
"SpecificationVersion", "TOSA specification version", "specification_version",
[Tosa_V_1_0, Tosa_V_1_1_DRAFT]>;
//===----------------------------------------------------------------------===//
// TOSA target environment.
//===----------------------------------------------------------------------===//
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
let summary = "Target environment information.";
let parameters = ( ins
"SpecificationVersion": $specification_version,
"Level": $level,
ArrayRefParameter<"Profile">: $profiles,
ArrayRefParameter<"Extension">: $extensions
);
let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` "
"`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
"`extensions` `=` `[` $extensions `]` `>`";
}

View File

@@ -36,12 +36,15 @@ enum CheckCondition {
allOf
};
using VersionedTypeInfo =
std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
template <typename T>
struct OpComplianceInfo {
// Certain operations require multiple modes enabled.
// e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
SmallVector<T> mode;
SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
SmallVector<VersionedTypeInfo> operandTypeInfoSet;
CheckCondition condition = CheckCondition::anyOf;
};
@@ -130,9 +133,8 @@ public:
// Find the required profiles or extensions from the compliance info according
// to the operand type combination.
template <typename T>
SmallVector<T> findMatchedProfile(Operation *op,
SmallVector<OpComplianceInfo<T>> compInfo,
CheckCondition &condition);
OpComplianceInfo<T>
findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
switch (ext) {
@@ -168,8 +170,7 @@ public:
private:
template <typename T>
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
CheckCondition &condition);
FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);
OperationProfileComplianceMap profileComplianceMap;
OperationExtensionComplianceMap extensionComplianceMap;

View File

@@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
];
let options = [
Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion",
/*default=*/"mlir::tosa::SpecificationVersion::V_1_0",
"The specification version that TOSA operators should conform to.",
[{::llvm::cl::values(
clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"),
clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft")
)}]>,
Option<"level", "level", "mlir::tosa::Level",
/*default=*/"mlir::tosa::Level::eightK",
"The TOSA level that operators should conform to. A TOSA level defines "

View File

@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
@@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) {
}
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
return TargetEnvAttr::get(context, Level::eightK,
return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
{Profile::pro_int, Profile::pro_fp}, {});
}
@@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
}
} // namespace tosa
} // namespace mlir

View File

@@ -61,8 +61,8 @@ public:
ModuleOp mod = getOperation();
MLIRContext *ctx = &getContext();
const auto targetEnvAttr =
TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
const auto targetEnvAttr = TargetEnvAttr::get(
ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
}

View File

@@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
//===----------------------------------------------------------------------===//
template <typename T>
FailureOr<SmallVector<T>>
TosaProfileCompliance::getOperatorDefinition(Operation *op,
CheckCondition &condition) {
FailureOr<OpComplianceInfo<T>>
TosaProfileCompliance::getOperatorDefinition(Operation *op) {
const std::string opName = op->getName().getStringRef().str();
const auto complianceMap = getProfileComplianceMap<T>();
const auto it = complianceMap.find(opName);
if (it == complianceMap.end())
return {};
return findMatchedProfile<T>(op, it->second, condition);
return findMatchedEntry<T>(op, it->second);
}
template <typename T>
@@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
if (specRequiredModeSet.size() == 0)
return success();
CheckCondition condition = CheckCondition::invalid;
const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
if (failed(maybeOpRequiredMode)) {
const auto maybeOpDefinition = getOperatorDefinition<T>(op);
if (failed(maybeOpDefinition)) {
// Operators such as control-flow and shape ops do not have an operand type
// restriction. When the profile compliance information of operation is not
// found, confirm if the target have enabled the profile required from the
// specification.
int mode_count = 0;
int modeCount = 0;
for (const auto &cands : specRequiredModeSet) {
if (targetEnv.allowsAnyOf(cands))
return success();
mode_count += cands.size();
modeCount += cands.size();
}
op->emitOpError() << "illegal: requires"
<< (mode_count > 1 ? " any of " : " ") << "["
<< (modeCount > 1 ? " any of " : " ") << "["
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
", ")
<< "] but not enabled in target\n";
@@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
// Find the required profiles or extensions according to the operand type
// combination.
const auto opRequiredMode = maybeOpRequiredMode.value();
const auto opDefinition = maybeOpDefinition.value();
const SmallVector<T> opRequiredMode = opDefinition.mode;
const CheckCondition condition = opDefinition.condition;
if (opRequiredMode.size() == 0) {
// No matched restriction found.
return success();
@@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
}
}
// Ensure the matched op compliance version does not exceed the target
// specification version.
const VersionedTypeInfo versionedTypeInfo =
opDefinition.operandTypeInfoSet[0];
const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
op->emitOpError() << "illegal: the target specification version ("
<< stringifyVersion(targetVersion)
<< ") is not backwards compatible with the op compliance "
"specification version ("
<< stringifyVersion(complianceVersion) << ")\n";
return failure();
}
return success();
}
@@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op,
}
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
const auto maybeProfDef = getOperatorDefinition<Profile>(op);
const auto maybeExtDef = getOperatorDefinition<Extension>(op);
if (failed(maybeProfDef) && failed(maybeExtDef))
return success();
const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
(succeeded(maybeExtDef) && !maybeExtDef->empty());
const bool hasEntry =
(succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
(succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
@@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
SmallVector<TypeInfo> bestTypeInfo;
const auto searchBestMatch = [&](auto map) {
for (const auto &complianceInfos : map[opName]) {
for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
for (const auto &versionedTypeInfos :
complianceInfos.operandTypeInfoSet) {
const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
const int matches = llvm::count_if(
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
return isSameTypeInfo(std::get<0>(zipType),
@@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
// Find the profiles or extensions requirement according to the signature of
// type of the operand list.
template <typename T>
SmallVector<T> TosaProfileCompliance::findMatchedProfile(
Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
CheckCondition &condition) {
OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
assert(compInfo.size() != 0 &&
"profile-based compliance information is empty");
@@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
return {};
for (size_t i = 0; i < compInfo.size(); i++) {
SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
for (SmallVector<TypeInfo> expected : sets) {
SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
for (const auto &set : sets) {
SmallVector<TypeInfo> expected = set.first;
assert(present.size() == expected.size() &&
"the entries for profile-based compliance do not match between "
"the generated metadata and the type definition retrieved from "
" the operation");
bool is_found = true;
bool isFound = true;
// Compare the type signature between the given operation and the
// compliance metadata.
for (size_t j = 0; j < expected.size(); j++) {
if (!isSameTypeInfo(present[j], expected[j])) {
// Verify the next mode set from the list.
is_found = false;
isFound = false;
break;
}
}
if (is_found == true) {
condition = compInfo[i].condition;
return compInfo[i].mode;
if (isFound == true) {
SmallVector<VersionedTypeInfo> typeInfoSet{set};
OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
compInfo[i].condition};
return info;
}
}
}

View File

@@ -1,12 +1,14 @@
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.1.draft" | FileCheck %s --check-prefix=CHECK-VERSION-1P1
// -----
// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
// CHECK-VERSION-1P1: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.1.draft", level = "8k", profiles = [], extensions = []>}
// CHECK-LABEL: test_simple
func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
%1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>

View File

@@ -0,0 +1,21 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
// -----
func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
// expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
return %0 : tensor<1x14x28xf16>
}
// -----
func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
// expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
return %0 : tensor<1x14x28xf32>
}

View File

@@ -0,0 +1,20 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
// -----
func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
return %0 : tensor<1x14x28xf16>
}
// -----
// CHECK-LABEL: test_matmul_fp8_input_fp32_acc_type
func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
return %0 : tensor<1x14x28xf32>
}