mirror of
https://github.com/intel/llvm.git
synced 2026-01-20 01:58:44 +08:00
[mlir][Attribute] Remove usages of Attribute::getKind
This is in preparation for removing the use of "kinds" within attributes and types in MLIR. Differential Revision: https://reviews.llvm.org/D85370
This commit is contained in:
@@ -661,10 +661,7 @@ public:
|
||||
function_ref<APInt(const APFloat &)> mapping) const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(Attribute attr) {
|
||||
return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
|
||||
attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
|
||||
}
|
||||
static bool classof(Attribute attr);
|
||||
|
||||
protected:
|
||||
/// Returns the 1 dimensional flattened row-major index from the given
|
||||
@@ -729,10 +726,7 @@ public:
|
||||
using ElementsAttr::ElementsAttr;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(Attribute attr) {
|
||||
return attr.getKind() == StandardAttributes::DenseIntOrFPElements ||
|
||||
attr.getKind() == StandardAttributes::DenseStringElements;
|
||||
}
|
||||
static bool classof(Attribute attr);
|
||||
|
||||
/// Constructs a dense elements attribute from an array of element values.
|
||||
/// Each element attribute value is expected to be an element of 'type'.
|
||||
@@ -1513,12 +1507,10 @@ class ElementsAttrIterator
|
||||
template <typename RetT, template <typename> class ProcessFn,
|
||||
typename... Args>
|
||||
RetT process(Args &... args) const {
|
||||
switch (attrKind) {
|
||||
case StandardAttributes::DenseIntOrFPElements:
|
||||
if (attr.isa<DenseElementsAttr>())
|
||||
return ProcessFn<DenseIteratorT>()(args...);
|
||||
case StandardAttributes::SparseElements:
|
||||
if (attr.isa<SparseElementsAttr>())
|
||||
return ProcessFn<SparseIteratorT>()(args...);
|
||||
}
|
||||
llvm_unreachable("unexpected attribute kind");
|
||||
}
|
||||
|
||||
@@ -1543,22 +1535,21 @@ class ElementsAttrIterator
|
||||
};
|
||||
|
||||
public:
|
||||
ElementsAttrIterator(const ElementsAttrIterator<T> &rhs)
|
||||
: attrKind(rhs.attrKind) {
|
||||
ElementsAttrIterator(const ElementsAttrIterator<T> &rhs) : attr(rhs.attr) {
|
||||
process<void, ConstructIter>(it, rhs.it);
|
||||
}
|
||||
~ElementsAttrIterator() { process<void, DestructIter>(it); }
|
||||
|
||||
/// Methods necessary to support random access iteration.
|
||||
ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
|
||||
assert(attrKind == rhs.attrKind && "incompatible iterators");
|
||||
assert(attr == rhs.attr && "incompatible iterators");
|
||||
return process<ptrdiff_t, Minus>(it, rhs.it);
|
||||
}
|
||||
bool operator==(const ElementsAttrIterator<T> &rhs) const {
|
||||
return rhs.attrKind == attrKind && process<bool, std::equal_to>(it, rhs.it);
|
||||
return rhs.attr == attr && process<bool, std::equal_to>(it, rhs.it);
|
||||
}
|
||||
bool operator<(const ElementsAttrIterator<T> &rhs) const {
|
||||
assert(attrKind == rhs.attrKind && "incompatible iterators");
|
||||
assert(attr == rhs.attr && "incompatible iterators");
|
||||
return process<bool, std::less>(it, rhs.it);
|
||||
}
|
||||
ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
|
||||
@@ -1575,14 +1566,14 @@ public:
|
||||
|
||||
private:
|
||||
template <typename IteratorT>
|
||||
ElementsAttrIterator(unsigned attrKind, IteratorT &&it)
|
||||
: attrKind(attrKind), it(std::forward<IteratorT>(it)) {}
|
||||
ElementsAttrIterator(Attribute attr, IteratorT &&it)
|
||||
: attr(attr), it(std::forward<IteratorT>(it)) {}
|
||||
|
||||
/// Allow accessing the constructor.
|
||||
friend ElementsAttr;
|
||||
|
||||
/// The kind of derived elements attribute.
|
||||
unsigned attrKind;
|
||||
/// The parent elements attribute.
|
||||
Attribute attr;
|
||||
|
||||
/// A union containing the specific iterators for each derived kind.
|
||||
Iterator it;
|
||||
@@ -1599,13 +1590,13 @@ template <typename T>
|
||||
auto ElementsAttr::getValues() const -> iterator_range<T> {
|
||||
if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
|
||||
auto values = denseAttr.getValues<T>();
|
||||
return {iterator<T>(getKind(), values.begin()),
|
||||
iterator<T>(getKind(), values.end())};
|
||||
return {iterator<T>(*this, values.begin()),
|
||||
iterator<T>(*this, values.end())};
|
||||
}
|
||||
if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
|
||||
auto values = sparseAttr.getValues<T>();
|
||||
return {iterator<T>(getKind(), values.begin()),
|
||||
iterator<T>(getKind(), values.end())};
|
||||
return {iterator<T>(*this, values.begin()),
|
||||
iterator<T>(*this, values.end())};
|
||||
}
|
||||
llvm_unreachable("unexpected attribute kind");
|
||||
}
|
||||
|
||||
@@ -42,10 +42,7 @@ public:
|
||||
using Attribute::Attribute;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(Attribute attr) {
|
||||
return attr.getKind() >= StandardAttributes::FIRST_LOCATION_ATTR &&
|
||||
attr.getKind() <= StandardAttributes::LAST_LOCATION_ATTR;
|
||||
}
|
||||
static bool classof(Attribute attr);
|
||||
};
|
||||
|
||||
/// This class defines the main interface for locations in MLIR and acts as a
|
||||
|
||||
@@ -1472,18 +1472,15 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
|
||||
// ODS already generates checks to make sure the result type is valid. We just
|
||||
// need to additionally check that the value's attribute type is consistent
|
||||
// with the result type.
|
||||
switch (value.getKind()) {
|
||||
case StandardAttributes::Integer:
|
||||
case StandardAttributes::Float: {
|
||||
if (value.isa<IntegerAttr, FloatAttr>()) {
|
||||
if (valueType != opType)
|
||||
return constOp.emitOpError("result type (")
|
||||
<< opType << ") does not match value type (" << valueType << ")";
|
||||
return success();
|
||||
} break;
|
||||
case StandardAttributes::DenseIntOrFPElements:
|
||||
case StandardAttributes::SparseElements: {
|
||||
}
|
||||
if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
|
||||
if (valueType == opType)
|
||||
break;
|
||||
return success();
|
||||
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
|
||||
auto shapedType = valueType.dyn_cast<ShapedType>();
|
||||
if (!arrayType) {
|
||||
@@ -1497,9 +1494,8 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
|
||||
numElements *= t.getNumElements();
|
||||
opElemType = t.getElementType();
|
||||
}
|
||||
if (!opElemType.isIntOrFloat()) {
|
||||
if (!opElemType.isIntOrFloat())
|
||||
return constOp.emitOpError("only support nested array result type");
|
||||
}
|
||||
|
||||
auto valueElemType = shapedType.getElementType();
|
||||
if (valueElemType != opElemType) {
|
||||
@@ -1513,26 +1509,24 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
|
||||
<< numElements << ") does not match value number of elements ("
|
||||
<< shapedType.getNumElements() << ")";
|
||||
}
|
||||
} break;
|
||||
case StandardAttributes::Array: {
|
||||
return success();
|
||||
}
|
||||
if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
|
||||
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
|
||||
if (!arrayType)
|
||||
return constOp.emitOpError(
|
||||
"must have spv.array result type for array value");
|
||||
auto elemType = arrayType.getElementType();
|
||||
for (auto element : value.cast<ArrayAttr>().getValue()) {
|
||||
Type elemType = arrayType.getElementType();
|
||||
for (Attribute element : attayAttr.getValue()) {
|
||||
if (element.getType() != elemType)
|
||||
return constOp.emitOpError("has array element whose type (")
|
||||
<< element.getType()
|
||||
<< ") does not match the result element type (" << elemType
|
||||
<< ')';
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
return constOp.emitOpError("cannot have value of type ") << valueType;
|
||||
return success();
|
||||
}
|
||||
|
||||
return success();
|
||||
return constOp.emitOpError("cannot have value of type ") << valueType;
|
||||
}
|
||||
|
||||
bool spirv::ConstantOp::isBuildableWith(Type type) {
|
||||
@@ -2619,19 +2613,14 @@ static LogicalResult verify(spirv::SpecConstantOp constOp) {
|
||||
return constOp.emitOpError("SpecId cannot be negative");
|
||||
|
||||
auto value = constOp.default_value();
|
||||
|
||||
switch (value.getKind()) {
|
||||
case StandardAttributes::Integer:
|
||||
case StandardAttributes::Float: {
|
||||
if (value.isa<IntegerAttr, FloatAttr>()) {
|
||||
// Make sure bitwidth is allowed.
|
||||
if (!value.getType().isa<spirv::SPIRVType>())
|
||||
return constOp.emitOpError("default value bitwidth disallowed");
|
||||
return success();
|
||||
}
|
||||
default:
|
||||
return constOp.emitOpError(
|
||||
"default value can only be a bool, integer, or float scalar");
|
||||
}
|
||||
return constOp.emitOpError(
|
||||
"default value can only be a bool, integer, or float scalar");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -33,6 +33,7 @@
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/SaveAndRestore.h"
|
||||
@@ -1019,76 +1020,67 @@ void ModulePrinter::printTrailingLocation(Location loc) {
|
||||
}
|
||||
|
||||
void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
|
||||
switch (loc.getKind()) {
|
||||
case StandardAttributes::OpaqueLocation:
|
||||
printLocationInternal(loc.cast<OpaqueLoc>().getFallbackLocation(), pretty);
|
||||
break;
|
||||
case StandardAttributes::UnknownLocation:
|
||||
if (pretty)
|
||||
os << "[unknown]";
|
||||
else
|
||||
os << "unknown";
|
||||
break;
|
||||
case StandardAttributes::FileLineColLocation: {
|
||||
auto fileLoc = loc.cast<FileLineColLoc>();
|
||||
auto mayQuote = pretty ? "" : "\"";
|
||||
os << mayQuote << fileLoc.getFilename() << mayQuote << ':'
|
||||
<< fileLoc.getLine() << ':' << fileLoc.getColumn();
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::NameLocation: {
|
||||
auto nameLoc = loc.cast<NameLoc>();
|
||||
os << '\"' << nameLoc.getName() << '\"';
|
||||
TypeSwitch<LocationAttr>(loc)
|
||||
.Case<OpaqueLoc>([&](OpaqueLoc loc) {
|
||||
printLocationInternal(loc.getFallbackLocation(), pretty);
|
||||
})
|
||||
.Case<UnknownLoc>([&](UnknownLoc loc) {
|
||||
if (pretty)
|
||||
os << "[unknown]";
|
||||
else
|
||||
os << "unknown";
|
||||
})
|
||||
.Case<FileLineColLoc>([&](FileLineColLoc loc) {
|
||||
StringRef mayQuote = pretty ? "" : "\"";
|
||||
os << mayQuote << loc.getFilename() << mayQuote << ':' << loc.getLine()
|
||||
<< ':' << loc.getColumn();
|
||||
})
|
||||
.Case<NameLoc>([&](NameLoc loc) {
|
||||
os << '\"' << loc.getName() << '\"';
|
||||
|
||||
// Print the child if it isn't unknown.
|
||||
auto childLoc = nameLoc.getChildLoc();
|
||||
if (!childLoc.isa<UnknownLoc>()) {
|
||||
os << '(';
|
||||
printLocationInternal(childLoc, pretty);
|
||||
os << ')';
|
||||
}
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::CallSiteLocation: {
|
||||
auto callLocation = loc.cast<CallSiteLoc>();
|
||||
auto caller = callLocation.getCaller();
|
||||
auto callee = callLocation.getCallee();
|
||||
if (!pretty)
|
||||
os << "callsite(";
|
||||
printLocationInternal(callee, pretty);
|
||||
if (pretty) {
|
||||
if (callee.isa<NameLoc>()) {
|
||||
if (caller.isa<FileLineColLoc>()) {
|
||||
os << " at ";
|
||||
} else {
|
||||
os << newLine << " at ";
|
||||
// Print the child if it isn't unknown.
|
||||
auto childLoc = loc.getChildLoc();
|
||||
if (!childLoc.isa<UnknownLoc>()) {
|
||||
os << '(';
|
||||
printLocationInternal(childLoc, pretty);
|
||||
os << ')';
|
||||
}
|
||||
} else {
|
||||
os << newLine << " at ";
|
||||
}
|
||||
} else {
|
||||
os << " at ";
|
||||
}
|
||||
printLocationInternal(caller, pretty);
|
||||
if (!pretty)
|
||||
os << ")";
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::FusedLocation: {
|
||||
auto fusedLoc = loc.cast<FusedLoc>();
|
||||
if (!pretty)
|
||||
os << "fused";
|
||||
if (auto metadata = fusedLoc.getMetadata())
|
||||
os << '<' << metadata << '>';
|
||||
os << '[';
|
||||
interleave(
|
||||
fusedLoc.getLocations(),
|
||||
[&](Location loc) { printLocationInternal(loc, pretty); },
|
||||
[&]() { os << ", "; });
|
||||
os << ']';
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
.Case<CallSiteLoc>([&](CallSiteLoc loc) {
|
||||
Location caller = loc.getCaller();
|
||||
Location callee = loc.getCallee();
|
||||
if (!pretty)
|
||||
os << "callsite(";
|
||||
printLocationInternal(callee, pretty);
|
||||
if (pretty) {
|
||||
if (callee.isa<NameLoc>()) {
|
||||
if (caller.isa<FileLineColLoc>()) {
|
||||
os << " at ";
|
||||
} else {
|
||||
os << newLine << " at ";
|
||||
}
|
||||
} else {
|
||||
os << newLine << " at ";
|
||||
}
|
||||
} else {
|
||||
os << " at ";
|
||||
}
|
||||
printLocationInternal(caller, pretty);
|
||||
if (!pretty)
|
||||
os << ")";
|
||||
})
|
||||
.Case<FusedLoc>([&](FusedLoc loc) {
|
||||
if (!pretty)
|
||||
os << "fused";
|
||||
if (Attribute metadata = loc.getMetadata())
|
||||
os << '<' << metadata << '>';
|
||||
os << '[';
|
||||
interleave(
|
||||
loc.getLocations(),
|
||||
[&](Location loc) { printLocationInternal(loc, pretty); },
|
||||
[&]() { os << ", "; });
|
||||
os << ']';
|
||||
});
|
||||
}
|
||||
|
||||
/// Print a floating point value in a way that the parser will be able to
|
||||
@@ -1305,27 +1297,19 @@ void ModulePrinter::printAttribute(Attribute attr,
|
||||
}
|
||||
|
||||
auto attrType = attr.getType();
|
||||
switch (attr.getKind()) {
|
||||
default:
|
||||
return printDialectAttribute(attr);
|
||||
|
||||
case StandardAttributes::Opaque: {
|
||||
auto opaqueAttr = attr.cast<OpaqueAttr>();
|
||||
if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
|
||||
printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
|
||||
opaqueAttr.getAttrData());
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::Unit:
|
||||
} else if (attr.isa<UnitAttr>()) {
|
||||
os << "unit";
|
||||
break;
|
||||
case StandardAttributes::Dictionary:
|
||||
return;
|
||||
} else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
|
||||
os << '{';
|
||||
interleaveComma(attr.cast<DictionaryAttr>().getValue(),
|
||||
interleaveComma(dictAttr.getValue(),
|
||||
[&](NamedAttribute attr) { printNamedAttribute(attr); });
|
||||
os << '}';
|
||||
break;
|
||||
case StandardAttributes::Integer: {
|
||||
auto intAttr = attr.cast<IntegerAttr>();
|
||||
|
||||
} else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
|
||||
if (attrType.isSignlessInteger(1)) {
|
||||
os << (intAttr.getValue().getBoolValue() ? "true" : "false");
|
||||
|
||||
@@ -1343,114 +1327,98 @@ void ModulePrinter::printAttribute(Attribute attr,
|
||||
// IntegerAttr elides the type if I64.
|
||||
if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
|
||||
return;
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::Float: {
|
||||
auto floatAttr = attr.cast<FloatAttr>();
|
||||
|
||||
} else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
|
||||
printFloatValue(floatAttr.getValue(), os);
|
||||
|
||||
// FloatAttr elides the type if F64.
|
||||
if (typeElision == AttrTypeElision::May && attrType.isF64())
|
||||
return;
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::String:
|
||||
|
||||
} else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
|
||||
os << '"';
|
||||
printEscapedString(attr.cast<StringAttr>().getValue(), os);
|
||||
printEscapedString(strAttr.getValue(), os);
|
||||
os << '"';
|
||||
break;
|
||||
case StandardAttributes::Array:
|
||||
|
||||
} else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
|
||||
os << '[';
|
||||
interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
|
||||
interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
|
||||
printAttribute(attr, AttrTypeElision::May);
|
||||
});
|
||||
os << ']';
|
||||
break;
|
||||
case StandardAttributes::AffineMap:
|
||||
|
||||
} else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
|
||||
os << "affine_map<";
|
||||
attr.cast<AffineMapAttr>().getValue().print(os);
|
||||
affineMapAttr.getValue().print(os);
|
||||
os << '>';
|
||||
|
||||
// AffineMap always elides the type.
|
||||
return;
|
||||
case StandardAttributes::IntegerSet:
|
||||
|
||||
} else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
|
||||
os << "affine_set<";
|
||||
attr.cast<IntegerSetAttr>().getValue().print(os);
|
||||
integerSetAttr.getValue().print(os);
|
||||
os << '>';
|
||||
|
||||
// IntegerSet always elides the type.
|
||||
return;
|
||||
case StandardAttributes::Type:
|
||||
printType(attr.cast<TypeAttr>().getValue());
|
||||
break;
|
||||
case StandardAttributes::SymbolRef: {
|
||||
auto refAttr = attr.dyn_cast<SymbolRefAttr>();
|
||||
|
||||
} else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
|
||||
printType(typeAttr.getValue());
|
||||
|
||||
} else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
|
||||
printSymbolReference(refAttr.getRootReference(), os);
|
||||
for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
|
||||
os << "::";
|
||||
printSymbolReference(nestedRef.getValue(), os);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::OpaqueElements: {
|
||||
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
|
||||
if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
|
||||
printElidedElementsAttr(os);
|
||||
break;
|
||||
}
|
||||
os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
|
||||
os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::DenseIntOrFPElements: {
|
||||
auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>();
|
||||
if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
|
||||
printElidedElementsAttr(os);
|
||||
break;
|
||||
}
|
||||
os << "dense<";
|
||||
printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true);
|
||||
os << '>';
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::DenseStringElements: {
|
||||
auto eltsAttr = attr.cast<DenseStringElementsAttr>();
|
||||
if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
|
||||
printElidedElementsAttr(os);
|
||||
break;
|
||||
}
|
||||
os << "dense<";
|
||||
printDenseStringElementsAttr(eltsAttr);
|
||||
os << '>';
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::SparseElements: {
|
||||
auto elementsAttr = attr.cast<SparseElementsAttr>();
|
||||
if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) ||
|
||||
printerFlags.shouldElideElementsAttr(elementsAttr.getValues())) {
|
||||
printElidedElementsAttr(os);
|
||||
break;
|
||||
}
|
||||
os << "sparse<";
|
||||
DenseIntElementsAttr indices = elementsAttr.getIndices();
|
||||
if (indices.getNumElements() != 0) {
|
||||
printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
|
||||
os << ", ";
|
||||
printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
|
||||
}
|
||||
os << '>';
|
||||
break;
|
||||
}
|
||||
|
||||
// Location attributes.
|
||||
case StandardAttributes::CallSiteLocation:
|
||||
case StandardAttributes::FileLineColLocation:
|
||||
case StandardAttributes::FusedLocation:
|
||||
case StandardAttributes::NameLocation:
|
||||
case StandardAttributes::OpaqueLocation:
|
||||
case StandardAttributes::UnknownLocation:
|
||||
printLocation(attr.cast<LocationAttr>());
|
||||
break;
|
||||
} else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
|
||||
if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
|
||||
printElidedElementsAttr(os);
|
||||
} else {
|
||||
os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", ";
|
||||
os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">";
|
||||
}
|
||||
|
||||
} else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
|
||||
if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
|
||||
printElidedElementsAttr(os);
|
||||
} else {
|
||||
os << "dense<";
|
||||
printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
|
||||
os << '>';
|
||||
}
|
||||
|
||||
} else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
|
||||
if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
|
||||
printElidedElementsAttr(os);
|
||||
} else {
|
||||
os << "dense<";
|
||||
printDenseStringElementsAttr(strEltAttr);
|
||||
os << '>';
|
||||
}
|
||||
|
||||
} else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
|
||||
if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
|
||||
printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
|
||||
printElidedElementsAttr(os);
|
||||
} else {
|
||||
os << "sparse<";
|
||||
DenseIntElementsAttr indices = sparseEltAttr.getIndices();
|
||||
if (indices.getNumElements() != 0) {
|
||||
printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
|
||||
os << ", ";
|
||||
printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
|
||||
}
|
||||
os << '>';
|
||||
}
|
||||
|
||||
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
|
||||
printLocation(locAttr);
|
||||
|
||||
} else {
|
||||
return printDialectAttribute(attr);
|
||||
}
|
||||
|
||||
// Don't print the type if we must elide it, or if it is a None type.
|
||||
|
||||
@@ -460,16 +460,11 @@ int64_t ElementsAttr::getNumElements() const {
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
/// element, then a null attribute is returned.
|
||||
Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
||||
switch (getKind()) {
|
||||
case StandardAttributes::DenseIntOrFPElements:
|
||||
return cast<DenseElementsAttr>().getValue(index);
|
||||
case StandardAttributes::OpaqueElements:
|
||||
return cast<OpaqueElementsAttr>().getValue(index);
|
||||
case StandardAttributes::SparseElements:
|
||||
return cast<SparseElementsAttr>().getValue(index);
|
||||
default:
|
||||
llvm_unreachable("unknown ElementsAttr kind");
|
||||
}
|
||||
if (auto denseAttr = dyn_cast<DenseElementsAttr>())
|
||||
return denseAttr.getValue(index);
|
||||
if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
|
||||
return opaqueAttr.getValue(index);
|
||||
return cast<SparseElementsAttr>().getValue(index);
|
||||
}
|
||||
|
||||
/// Return if the given 'index' refers to a valid element in this attribute.
|
||||
@@ -491,23 +486,23 @@ bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
|
||||
ElementsAttr
|
||||
ElementsAttr::mapValues(Type newElementType,
|
||||
function_ref<APInt(const APInt &)> mapping) const {
|
||||
switch (getKind()) {
|
||||
case StandardAttributes::DenseIntOrFPElements:
|
||||
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
|
||||
default:
|
||||
llvm_unreachable("unsupported ElementsAttr subtype");
|
||||
}
|
||||
if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
|
||||
return intOrFpAttr.mapValues(newElementType, mapping);
|
||||
llvm_unreachable("unsupported ElementsAttr subtype");
|
||||
}
|
||||
|
||||
ElementsAttr
|
||||
ElementsAttr::mapValues(Type newElementType,
|
||||
function_ref<APInt(const APFloat &)> mapping) const {
|
||||
switch (getKind()) {
|
||||
case StandardAttributes::DenseIntOrFPElements:
|
||||
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
|
||||
default:
|
||||
llvm_unreachable("unsupported ElementsAttr subtype");
|
||||
}
|
||||
if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
|
||||
return intOrFpAttr.mapValues(newElementType, mapping);
|
||||
llvm_unreachable("unsupported ElementsAttr subtype");
|
||||
}
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
bool ElementsAttr::classof(Attribute attr) {
|
||||
return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
|
||||
OpaqueElementsAttr, SparseElementsAttr>();
|
||||
}
|
||||
|
||||
/// Returns the 1 dimensional flattened row-major index from the given
|
||||
@@ -718,6 +713,11 @@ DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
|
||||
// DenseElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
bool DenseElementsAttr::classof(Attribute attr) {
|
||||
return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
ArrayRef<Attribute> values) {
|
||||
assert(hasSameElementsOrSplat(type, values));
|
||||
|
||||
@@ -366,43 +366,38 @@ struct SourceMgrDiagnosticHandlerImpl {
|
||||
|
||||
/// Return a processable FileLineColLoc from the given location.
|
||||
static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
|
||||
switch (loc->getKind()) {
|
||||
case StandardAttributes::NameLocation:
|
||||
if (auto nameLoc = loc.dyn_cast<NameLoc>())
|
||||
return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
|
||||
case StandardAttributes::FileLineColLocation:
|
||||
return loc.cast<FileLineColLoc>();
|
||||
case StandardAttributes::CallSiteLocation:
|
||||
// Process the callee of a callsite location.
|
||||
if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
|
||||
return fileLoc;
|
||||
if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
|
||||
return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
|
||||
case StandardAttributes::FusedLocation:
|
||||
if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
|
||||
for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
|
||||
if (auto callLoc = getFileLineColLoc(subLoc)) {
|
||||
return callLoc;
|
||||
}
|
||||
}
|
||||
return llvm::None;
|
||||
default:
|
||||
return llvm::None;
|
||||
}
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// Return a processable CallSiteLoc from the given location.
|
||||
static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
|
||||
switch (loc->getKind()) {
|
||||
case StandardAttributes::NameLocation:
|
||||
if (auto nameLoc = loc.dyn_cast<NameLoc>())
|
||||
return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
|
||||
case StandardAttributes::CallSiteLocation:
|
||||
return loc.cast<CallSiteLoc>();
|
||||
case StandardAttributes::FusedLocation:
|
||||
if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
|
||||
return callLoc;
|
||||
if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
|
||||
for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
|
||||
if (auto callLoc = getCallSiteLoc(subLoc)) {
|
||||
return callLoc;
|
||||
}
|
||||
}
|
||||
return llvm::None;
|
||||
default:
|
||||
return llvm::None;
|
||||
}
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
/// Given a diagnostic kind, returns the LLVM DiagKind.
|
||||
|
||||
@@ -13,6 +13,16 @@
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LocationAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
bool LocationAttr::classof(Attribute attr) {
|
||||
return attr.isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
|
||||
UnknownLoc>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CallSiteLoc
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -115,26 +115,19 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
|
||||
return existingIt->second;
|
||||
|
||||
const llvm::DILocation *llvmLoc = nullptr;
|
||||
switch (loc->getKind()) {
|
||||
case StandardAttributes::CallSiteLocation: {
|
||||
auto callLoc = loc.dyn_cast<CallSiteLoc>();
|
||||
|
||||
if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
|
||||
// For callsites, the caller is fed as the inlinedAt for the callee.
|
||||
const auto *callerLoc = translateLoc(callLoc.getCaller(), scope, inlinedAt);
|
||||
llvmLoc = translateLoc(callLoc.getCallee(), scope, callerLoc);
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::FileLineColLocation: {
|
||||
auto fileLoc = loc.dyn_cast<FileLineColLoc>();
|
||||
|
||||
} else if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
|
||||
auto *file = translateFile(fileLoc.getFilename());
|
||||
auto *fileScope = builder.createLexicalBlockFile(scope, file);
|
||||
llvmLoc = llvm::DILocation::get(llvmCtx, fileLoc.getLine(),
|
||||
fileLoc.getColumn(), fileScope,
|
||||
const_cast<llvm::DILocation *>(inlinedAt));
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::FusedLocation: {
|
||||
auto fusedLoc = loc.dyn_cast<FusedLoc>();
|
||||
|
||||
} else if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
|
||||
ArrayRef<Location> locations = fusedLoc.getLocations();
|
||||
|
||||
// For fused locations, merge each of the nodes.
|
||||
@@ -143,18 +136,17 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
|
||||
llvmLoc = llvm::DILocation::getMergedLocation(
|
||||
llvmLoc, translateLoc(locIt, scope, inlinedAt));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::NameLocation:
|
||||
|
||||
} else if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
|
||||
llvmLoc = translateLoc(loc.cast<NameLoc>().getChildLoc(), scope, inlinedAt);
|
||||
break;
|
||||
case StandardAttributes::OpaqueLocation:
|
||||
|
||||
} else if (auto opaqueLoc = loc.dyn_cast<OpaqueLoc>()) {
|
||||
llvmLoc = translateLoc(loc.cast<OpaqueLoc>().getFallbackLocation(), scope,
|
||||
inlinedAt);
|
||||
break;
|
||||
default:
|
||||
} else {
|
||||
llvm_unreachable("unknown location kind");
|
||||
}
|
||||
|
||||
locationToLoc.try_emplace(std::make_pair(loc, scope), llvmLoc);
|
||||
return llvmLoc;
|
||||
}
|
||||
|
||||
@@ -158,7 +158,7 @@ TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
|
||||
auto expectedTensorType = realValue.getType().cast<TensorType>();
|
||||
EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
|
||||
EXPECT_EQ(tensorType.getElementType(), convertedType);
|
||||
EXPECT_EQ(returnedValue.getKind(), StandardAttributes::SparseElements);
|
||||
EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());
|
||||
|
||||
// Check Elements attribute element value is expected.
|
||||
auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
|
||||
|
||||
Reference in New Issue
Block a user