mirror of
https://github.com/intel/llvm.git
synced 2026-01-22 23:49:22 +08:00
[mlir][tosa] Add specification versioning to target environment (#156425)
This commit adds a new "specification_version" field to the TOSA target environment attribute. This allows a user to specify which version of the TOSA specification they would like to target during lowering. A leading example in the validation pass has also been added. This addition adds a version to each profile compliance entry to track the version of the specification in which the entry was added. This allows a backwards compatibility check to be implemented between the target version and the profile compliance entry version. For now a default version of "1.0" is assumed. "1.1.draft" is added to denote an in-development version of the specification targeting the next release.
This commit is contained in:
@@ -50,28 +50,63 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
|
||||
/// returned by getDefaultTargetEnv() if not provided.
|
||||
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
|
||||
|
||||
/// A thin wrapper around the SpecificationVersion enum to represent
|
||||
/// and provide utilities around the TOSA specification version.
|
||||
class TosaSpecificationVersion {
|
||||
public:
|
||||
TosaSpecificationVersion(uint32_t major, uint32_t minor)
|
||||
: majorVersion(major), minorVersion(minor) {}
|
||||
TosaSpecificationVersion(SpecificationVersion version)
|
||||
: TosaSpecificationVersion(fromVersionEnum(version)) {}
|
||||
|
||||
bool isBackwardsCompatibleWith(TosaSpecificationVersion baseVersion) const {
|
||||
return this->majorVersion == baseVersion.majorVersion &&
|
||||
this->minorVersion >= baseVersion.minorVersion;
|
||||
}
|
||||
|
||||
uint32_t getMajor() const { return majorVersion; }
|
||||
uint32_t getMinor() const { return minorVersion; }
|
||||
|
||||
private:
|
||||
uint32_t majorVersion = 0;
|
||||
uint32_t minorVersion = 0;
|
||||
|
||||
static TosaSpecificationVersion
|
||||
fromVersionEnum(SpecificationVersion version) {
|
||||
switch (version) {
|
||||
case SpecificationVersion::V_1_0:
|
||||
return TosaSpecificationVersion(1, 0);
|
||||
case SpecificationVersion::V_1_1_DRAFT:
|
||||
return TosaSpecificationVersion(1, 1);
|
||||
}
|
||||
llvm_unreachable("Unknown TOSA version");
|
||||
}
|
||||
};
|
||||
|
||||
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
|
||||
|
||||
/// This class represents the capability enabled in the target implementation
|
||||
/// such as profile, extension, and level. It's a wrapper class around
|
||||
/// tosa::TargetEnvAttr.
|
||||
class TargetEnv {
|
||||
public:
|
||||
TargetEnv() {}
|
||||
explicit TargetEnv(Level level, const ArrayRef<Profile> &profiles,
|
||||
explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
|
||||
const ArrayRef<Profile> &profiles,
|
||||
const ArrayRef<Extension> &extensions)
|
||||
: level(level) {
|
||||
: specificationVersion(specificationVersion), level(level) {
|
||||
enabledProfiles.insert_range(profiles);
|
||||
enabledExtensions.insert_range(extensions);
|
||||
}
|
||||
|
||||
explicit TargetEnv(TargetEnvAttr targetAttr)
|
||||
: TargetEnv(targetAttr.getLevel(), targetAttr.getProfiles(),
|
||||
targetAttr.getExtensions()) {}
|
||||
: TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
|
||||
targetAttr.getProfiles(), targetAttr.getExtensions()) {}
|
||||
|
||||
void addProfile(Profile p) { enabledProfiles.insert(p); }
|
||||
void addExtension(Extension e) { enabledExtensions.insert(e); }
|
||||
|
||||
// TODO implement the following utilities.
|
||||
// Version getSpecVersion() const;
|
||||
SpecificationVersion getSpecVersion() const { return specificationVersion; }
|
||||
|
||||
TosaLevel getLevel() const {
|
||||
if (level == Level::eightK)
|
||||
@@ -105,6 +140,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
SpecificationVersion specificationVersion;
|
||||
Level level;
|
||||
llvm::SmallSet<Profile, 3> enabledProfiles;
|
||||
llvm::SmallSet<Extension, 13> enabledExtensions;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -221,7 +221,7 @@ class Tosa_I32EnumAttr<string name, string description, string mnemonic,
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TOSA Spec Section 1.5.
|
||||
// TOSA Profiles and extensions
|
||||
//
|
||||
// Profile:
|
||||
// INT : Integer Inference. Integer operations, primarily 8 and 32-bit values.
|
||||
@@ -293,12 +293,6 @@ def Tosa_ExtensionAttr
|
||||
def Tosa_ExtensionArrayAttr
|
||||
: TypedArrayAttrBase<Tosa_ExtensionAttr, "TOSA extension array attribute">;
|
||||
|
||||
def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
|
||||
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
|
||||
|
||||
def Tosa_LevelAttr
|
||||
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
|
||||
|
||||
// The base class for defining op availability dimensions.
|
||||
class Availability {
|
||||
// The following are fields for controlling the generated C++ OpInterface.
|
||||
@@ -404,18 +398,41 @@ class Extension<list<I32EnumAttrCase> extensions> : Availability {
|
||||
let instance = "ref";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TOSA Levels
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Tosa_LVL_NONE : I32EnumAttrCase<"none", 0>;
|
||||
def Tosa_LVL_8K : I32EnumAttrCase<"eightK", 1, "8k">;
|
||||
|
||||
def Tosa_LevelAttr
|
||||
: Tosa_I32EnumAttr<"Level", "supported TOSA levels", "level", [Tosa_LVL_NONE, Tosa_LVL_8K]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TOSA Specification versions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Enum case for specification version 1.0 (value 0, spelled "1.0").
def Tosa_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "1.0">;
|
||||
// Enum case for the in-development 1.1 draft (value 1, spelled "1.1.draft").
def Tosa_V_1_1_DRAFT : I32EnumAttrCase<"V_1_1_DRAFT", 1, "1.1.draft">;
|
||||
|
||||
// Enum attribute listing the selectable specification versions; it is
// registered under the `specification_version` mnemonic.
def Tosa_SpecificationVersion : Tosa_I32EnumAttr<
|
||||
"SpecificationVersion", "TOSA specification version", "specification_version",
|
||||
[Tosa_V_1_0, Tosa_V_1_1_DRAFT]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TOSA target environment.
|
||||
//===----------------------------------------------------------------------===//
|
||||
def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
|
||||
let summary = "Target environment information.";
|
||||
let parameters = ( ins
|
||||
"SpecificationVersion": $specification_version,
|
||||
"Level": $level,
|
||||
ArrayRefParameter<"Profile">: $profiles,
|
||||
ArrayRefParameter<"Extension">: $extensions
|
||||
);
|
||||
|
||||
let assemblyFormat = "`<` `level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
|
||||
let assemblyFormat = "`<` `specification_version` `=` $specification_version `,` "
|
||||
"`level` `=` $level `,` `profiles` `=` `[` $profiles `]` `,` "
|
||||
"`extensions` `=` `[` $extensions `]` `>`";
|
||||
}
|
||||
|
||||
|
||||
@@ -36,12 +36,15 @@ enum CheckCondition {
|
||||
allOf
|
||||
};
|
||||
|
||||
// An operand type-info list paired with the specification version recorded
// for that compliance entry.
using VersionedTypeInfo =
|
||||
std::pair<SmallVector<TypeInfo>, SpecificationVersion>;
|
||||
|
||||
template <typename T>
|
||||
struct OpComplianceInfo {
|
||||
// Certain operations require multiple modes enabled.
|
||||
// e.g. cast bf16 to fp8e4m3 requires EXT-BF16 and EXT-FP8E4M3.
|
||||
SmallVector<T> mode;
|
||||
SmallVector<SmallVector<TypeInfo>> operandTypeInfoSet;
|
||||
SmallVector<VersionedTypeInfo> operandTypeInfoSet;
|
||||
CheckCondition condition = CheckCondition::anyOf;
|
||||
};
|
||||
|
||||
@@ -130,9 +133,8 @@ public:
|
||||
// Find the required profiles or extensions from the compliance info according
|
||||
// to the operand type combination.
|
||||
template <typename T>
|
||||
SmallVector<T> findMatchedProfile(Operation *op,
|
||||
SmallVector<OpComplianceInfo<T>> compInfo,
|
||||
CheckCondition &condition);
|
||||
OpComplianceInfo<T>
|
||||
findMatchedEntry(Operation *op, SmallVector<OpComplianceInfo<T>> compInfo);
|
||||
|
||||
SmallVector<Profile> getCooperativeProfiles(Extension ext) {
|
||||
switch (ext) {
|
||||
@@ -168,8 +170,7 @@ public:
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
FailureOr<SmallVector<T>> getOperatorDefinition(Operation *op,
|
||||
CheckCondition &condition);
|
||||
FailureOr<OpComplianceInfo<T>> getOperatorDefinition(Operation *op);
|
||||
|
||||
OperationProfileComplianceMap profileComplianceMap;
|
||||
OperationExtensionComplianceMap extensionComplianceMap;
|
||||
|
||||
@@ -137,6 +137,13 @@ def TosaAttachTarget : Pass<"tosa-attach-target", "ModuleOp"> {
|
||||
];
|
||||
|
||||
let options = [
|
||||
Option<"specificationVersion", "specification_version", "mlir::tosa::SpecificationVersion",
|
||||
/*default=*/"mlir::tosa::SpecificationVersion::V_1_0",
|
||||
"The specification version that TOSA operators should conform to.",
|
||||
[{::llvm::cl::values(
|
||||
clEnumValN(mlir::tosa::SpecificationVersion::V_1_0, "1.0", "TOSA Specification version 1.0"),
|
||||
clEnumValN(mlir::tosa::SpecificationVersion::V_1_1_DRAFT, "1.1.draft", "TOSA Specification version 1.1.draft")
|
||||
)}]>,
|
||||
Option<"level", "level", "mlir::tosa::Level",
|
||||
/*default=*/"mlir::tosa::Level::eightK",
|
||||
"The TOSA level that operators should conform to. A TOSA level defines "
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Tosa/IR/TargetEnv.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tosa {
|
||||
@@ -27,7 +28,7 @@ TargetEnvAttr lookupTargetEnv(Operation *op) {
|
||||
}
|
||||
|
||||
TargetEnvAttr getDefaultTargetEnv(MLIRContext *context) {
|
||||
return TargetEnvAttr::get(context, Level::eightK,
|
||||
return TargetEnvAttr::get(context, SpecificationVersion::V_1_0, Level::eightK,
|
||||
{Profile::pro_int, Profile::pro_fp}, {});
|
||||
}
|
||||
|
||||
@@ -38,5 +39,9 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
|
||||
return getDefaultTargetEnv(op->getContext());
|
||||
}
|
||||
|
||||
/// Renders `version` as "<major>.<minor>", e.g. "1.0".
llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
  // Name the two components first so the format call reads positionally.
  const uint32_t majorPart = version.getMajor();
  const uint32_t minorPart = version.getMinor();
  return llvm::formatv("{0}.{1}", majorPart, minorPart);
}
|
||||
|
||||
} // namespace tosa
|
||||
} // namespace mlir
|
||||
|
||||
@@ -61,8 +61,8 @@ public:
|
||||
|
||||
ModuleOp mod = getOperation();
|
||||
MLIRContext *ctx = &getContext();
|
||||
const auto targetEnvAttr =
|
||||
TargetEnvAttr::get(ctx, level, selectedProfiles, selectedExtensions);
|
||||
const auto targetEnvAttr = TargetEnvAttr::get(
|
||||
ctx, specificationVersion, level, selectedProfiles, selectedExtensions);
|
||||
mod->setAttr(TargetEnvAttr::name, targetEnvAttr);
|
||||
}
|
||||
|
||||
|
||||
@@ -335,16 +335,15 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename T>
|
||||
FailureOr<SmallVector<T>>
|
||||
TosaProfileCompliance::getOperatorDefinition(Operation *op,
|
||||
CheckCondition &condition) {
|
||||
FailureOr<OpComplianceInfo<T>>
|
||||
TosaProfileCompliance::getOperatorDefinition(Operation *op) {
|
||||
const std::string opName = op->getName().getStringRef().str();
|
||||
const auto complianceMap = getProfileComplianceMap<T>();
|
||||
const auto it = complianceMap.find(opName);
|
||||
if (it == complianceMap.end())
|
||||
return {};
|
||||
|
||||
return findMatchedProfile<T>(op, it->second, condition);
|
||||
return findMatchedEntry<T>(op, it->second);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@@ -356,22 +355,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
|
||||
if (specRequiredModeSet.size() == 0)
|
||||
return success();
|
||||
|
||||
CheckCondition condition = CheckCondition::invalid;
|
||||
const auto maybeOpRequiredMode = getOperatorDefinition<T>(op, condition);
|
||||
if (failed(maybeOpRequiredMode)) {
|
||||
const auto maybeOpDefinition = getOperatorDefinition<T>(op);
|
||||
if (failed(maybeOpDefinition)) {
|
||||
// Operators such as control-flow and shape ops do not have an operand type
|
||||
// restriction. When the profile compliance information of operation is not
|
||||
// found, confirm if the target have enabled the profile required from the
|
||||
// specification.
|
||||
int mode_count = 0;
|
||||
int modeCount = 0;
|
||||
for (const auto &cands : specRequiredModeSet) {
|
||||
if (targetEnv.allowsAnyOf(cands))
|
||||
return success();
|
||||
mode_count += cands.size();
|
||||
modeCount += cands.size();
|
||||
}
|
||||
|
||||
op->emitOpError() << "illegal: requires"
|
||||
<< (mode_count > 1 ? " any of " : " ") << "["
|
||||
<< (modeCount > 1 ? " any of " : " ") << "["
|
||||
<< llvm::join(stringifyProfile<T>(specRequiredModeSet),
|
||||
", ")
|
||||
<< "] but not enabled in target\n";
|
||||
@@ -381,7 +379,10 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
|
||||
|
||||
// Find the required profiles or extensions according to the operand type
|
||||
// combination.
|
||||
const auto opRequiredMode = maybeOpRequiredMode.value();
|
||||
const auto opDefinition = maybeOpDefinition.value();
|
||||
const SmallVector<T> opRequiredMode = opDefinition.mode;
|
||||
const CheckCondition condition = opDefinition.condition;
|
||||
|
||||
if (opRequiredMode.size() == 0) {
|
||||
// No matched restriction found.
|
||||
return success();
|
||||
@@ -437,6 +438,21 @@ LogicalResult TosaProfileCompliance::checkProfileOrExtension(
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure the matched op compliance version does not exceed the target
|
||||
// specification version.
|
||||
const VersionedTypeInfo versionedTypeInfo =
|
||||
opDefinition.operandTypeInfoSet[0];
|
||||
const TosaSpecificationVersion complianceVersion{versionedTypeInfo.second};
|
||||
const TosaSpecificationVersion targetVersion{targetEnv.getSpecVersion()};
|
||||
if (!targetVersion.isBackwardsCompatibleWith(complianceVersion)) {
|
||||
op->emitOpError() << "illegal: the target specification version ("
|
||||
<< stringifyVersion(targetVersion)
|
||||
<< ") is not backwards compatible with the op compliance "
|
||||
"specification version ("
|
||||
<< stringifyVersion(complianceVersion) << ")\n";
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -461,14 +477,14 @@ TosaProfileCompliance::checkExtension(Operation *op,
|
||||
}
|
||||
|
||||
LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
|
||||
CheckCondition condition = CheckCondition::invalid;
|
||||
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
|
||||
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
|
||||
const auto maybeProfDef = getOperatorDefinition<Profile>(op);
|
||||
const auto maybeExtDef = getOperatorDefinition<Extension>(op);
|
||||
if (failed(maybeProfDef) && failed(maybeExtDef))
|
||||
return success();
|
||||
|
||||
const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
|
||||
(succeeded(maybeExtDef) && !maybeExtDef->empty());
|
||||
const bool hasEntry =
|
||||
(succeeded(maybeProfDef) && !maybeProfDef->mode.empty()) ||
|
||||
(succeeded(maybeExtDef) && !maybeExtDef->mode.empty());
|
||||
if (!hasEntry) {
|
||||
std::string message;
|
||||
llvm::raw_string_ostream os(message);
|
||||
@@ -488,7 +504,9 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
|
||||
SmallVector<TypeInfo> bestTypeInfo;
|
||||
const auto searchBestMatch = [&](auto map) {
|
||||
for (const auto &complianceInfos : map[opName]) {
|
||||
for (const auto &typeInfos : complianceInfos.operandTypeInfoSet) {
|
||||
for (const auto &versionedTypeInfos :
|
||||
complianceInfos.operandTypeInfoSet) {
|
||||
const SmallVector<TypeInfo> typeInfos = versionedTypeInfos.first;
|
||||
const int matches = llvm::count_if(
|
||||
llvm::zip_equal(current, typeInfos), [&](const auto zipType) {
|
||||
return isSameTypeInfo(std::get<0>(zipType),
|
||||
@@ -520,9 +538,8 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
|
||||
// Find the profiles or extensions requirement according to the signature of
|
||||
// type of the operand list.
|
||||
template <typename T>
|
||||
SmallVector<T> TosaProfileCompliance::findMatchedProfile(
|
||||
Operation *op, SmallVector<OpComplianceInfo<T>> compInfo,
|
||||
CheckCondition &condition) {
|
||||
OpComplianceInfo<T> TosaProfileCompliance::findMatchedEntry(
|
||||
Operation *op, SmallVector<OpComplianceInfo<T>> compInfo) {
|
||||
assert(compInfo.size() != 0 &&
|
||||
"profile-based compliance information is empty");
|
||||
|
||||
@@ -533,27 +550,30 @@ SmallVector<T> TosaProfileCompliance::findMatchedProfile(
|
||||
return {};
|
||||
|
||||
for (size_t i = 0; i < compInfo.size(); i++) {
|
||||
SmallVector<SmallVector<TypeInfo>> sets = compInfo[i].operandTypeInfoSet;
|
||||
for (SmallVector<TypeInfo> expected : sets) {
|
||||
SmallVector<VersionedTypeInfo> sets = compInfo[i].operandTypeInfoSet;
|
||||
for (const auto &set : sets) {
|
||||
SmallVector<TypeInfo> expected = set.first;
|
||||
assert(present.size() == expected.size() &&
|
||||
"the entries for profile-based compliance do not match between "
|
||||
"the generated metadata and the type definition retrieved from "
|
||||
" the operation");
|
||||
|
||||
bool is_found = true;
|
||||
bool isFound = true;
|
||||
// Compare the type signature between the given operation and the
|
||||
// compliance metadata.
|
||||
for (size_t j = 0; j < expected.size(); j++) {
|
||||
if (!isSameTypeInfo(present[j], expected[j])) {
|
||||
// Verify the next mode set from the list.
|
||||
is_found = false;
|
||||
isFound = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_found == true) {
|
||||
condition = compInfo[i].condition;
|
||||
return compInfo[i].mode;
|
||||
if (isFound == true) {
|
||||
SmallVector<VersionedTypeInfo> typeInfoSet{set};
|
||||
OpComplianceInfo<T> info{compInfo[i].mode, typeInfoSet,
|
||||
compInfo[i].condition};
|
||||
return info;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround,dynamic level=none" | FileCheck %s --check-prefix=CHECK-ALL
|
||||
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="level=8k" | FileCheck %s --check-prefix=CHECK-LVL-8K
|
||||
// RUN: mlir-opt %s -split-input-file -tosa-attach-target | FileCheck %s --check-prefix=CHECK-DEFAULT
|
||||
// RUN: mlir-opt %s -split-input-file -tosa-attach-target="specification_version=1.1.draft" | FileCheck %s --check-prefix=CHECK-VERSION-1P1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
|
||||
// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
|
||||
// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<level = "8k", profiles = [], extensions = []>}
|
||||
// CHECK-ALL: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = none, profiles = [pro_int, pro_fp], extensions = [int16, int4, bf16, fp8e4m3, fp8e5m2, fft, variable, controlflow, doubleround, inexactround, dynamic]>}
|
||||
// CHECK-LVL-8K: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
|
||||
// CHECK-DEFAULT: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.0", level = "8k", profiles = [], extensions = []>}
|
||||
// CHECK-VERSION-1P1: module attributes {tosa.target_env = #tosa.target_env<specification_version = "1.1.draft", level = "8k", profiles = [], extensions = []>}
|
||||
// CHECK-LABEL: test_simple
|
||||
func.func @test_simple(%arg0 : tensor<1x1x1x1xf32>, %arg1 : tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> {
|
||||
%1 = tosa.add %arg0, %arg1 : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.0 profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,dynamic,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
|
||||
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
|
||||
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
|
||||
// expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
|
||||
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
|
||||
return %0 : tensor<1x14x28xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
|
||||
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
|
||||
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
|
||||
// expected-error@+1 {{'tosa.matmul' op illegal: the target specification version (1.0) is not backwards compatible with the op compliance specification version (1.1)}}
|
||||
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
|
||||
return %0 : tensor<1x14x28xf32>
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" | FileCheck %s
|
||||
|
||||
// -----
|
||||
|
||||
func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
|
||||
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
|
||||
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
|
||||
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E5M2>, tensor<1xf8E4M3FN>, tensor<1xf8E5M2>) -> tensor<1x14x28xf16>
|
||||
return %0 : tensor<1x14x28xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_matmul_fp8_input_fp32_acc_type
|
||||
func.func @test_matmul_fp8_input_fp32_acc_type(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf32> {
|
||||
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
|
||||
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
|
||||
%0 = tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf32>
|
||||
return %0 : tensor<1x14x28xf32>
|
||||
}
|
||||
Reference in New Issue
Block a user