[mlir][Attribute] Remove usages of Attribute::getKind

This is in preparation for removing the use of "kinds" within attributes and types in MLIR.

Differential Revision: https://reviews.llvm.org/D85370
This commit is contained in:
River Riddle
2020-08-07 13:30:29 -07:00
parent 1d6a8deb41
commit fff39b62bb
9 changed files with 219 additions and 277 deletions

View File

@@ -661,10 +661,7 @@ public:
function_ref<APInt(const APFloat &)> mapping) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
return attr.getKind() >= StandardAttributes::FIRST_ELEMENTS_ATTR &&
attr.getKind() <= StandardAttributes::LAST_ELEMENTS_ATTR;
}
static bool classof(Attribute attr);
protected:
/// Returns the 1 dimensional flattened row-major index from the given
@@ -729,10 +726,7 @@ public:
using ElementsAttr::ElementsAttr;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {
return attr.getKind() == StandardAttributes::DenseIntOrFPElements ||
attr.getKind() == StandardAttributes::DenseStringElements;
}
static bool classof(Attribute attr);
/// Constructs a dense elements attribute from an array of element values.
/// Each element attribute value is expected to be an element of 'type'.
@@ -1513,12 +1507,10 @@ class ElementsAttrIterator
template <typename RetT, template <typename> class ProcessFn,
typename... Args>
RetT process(Args &... args) const {
switch (attrKind) {
case StandardAttributes::DenseIntOrFPElements:
if (attr.isa<DenseElementsAttr>())
return ProcessFn<DenseIteratorT>()(args...);
case StandardAttributes::SparseElements:
if (attr.isa<SparseElementsAttr>())
return ProcessFn<SparseIteratorT>()(args...);
}
llvm_unreachable("unexpected attribute kind");
}
@@ -1543,22 +1535,21 @@ class ElementsAttrIterator
};
public:
ElementsAttrIterator(const ElementsAttrIterator<T> &rhs)
: attrKind(rhs.attrKind) {
ElementsAttrIterator(const ElementsAttrIterator<T> &rhs) : attr(rhs.attr) {
process<void, ConstructIter>(it, rhs.it);
}
~ElementsAttrIterator() { process<void, DestructIter>(it); }
/// Methods necessary to support random access iteration.
ptrdiff_t operator-(const ElementsAttrIterator<T> &rhs) const {
assert(attrKind == rhs.attrKind && "incompatible iterators");
assert(attr == rhs.attr && "incompatible iterators");
return process<ptrdiff_t, Minus>(it, rhs.it);
}
bool operator==(const ElementsAttrIterator<T> &rhs) const {
return rhs.attrKind == attrKind && process<bool, std::equal_to>(it, rhs.it);
return rhs.attr == attr && process<bool, std::equal_to>(it, rhs.it);
}
bool operator<(const ElementsAttrIterator<T> &rhs) const {
assert(attrKind == rhs.attrKind && "incompatible iterators");
assert(attr == rhs.attr && "incompatible iterators");
return process<bool, std::less>(it, rhs.it);
}
ElementsAttrIterator<T> &operator+=(ptrdiff_t offset) {
@@ -1575,14 +1566,14 @@ public:
private:
template <typename IteratorT>
ElementsAttrIterator(unsigned attrKind, IteratorT &&it)
: attrKind(attrKind), it(std::forward<IteratorT>(it)) {}
ElementsAttrIterator(Attribute attr, IteratorT &&it)
: attr(attr), it(std::forward<IteratorT>(it)) {}
/// Allow accessing the constructor.
friend ElementsAttr;
/// The kind of derived elements attribute.
unsigned attrKind;
/// The parent elements attribute.
Attribute attr;
/// A union containing the specific iterators for each derived kind.
Iterator it;
@@ -1599,13 +1590,13 @@ template <typename T>
auto ElementsAttr::getValues() const -> iterator_range<T> {
if (DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>()) {
auto values = denseAttr.getValues<T>();
return {iterator<T>(getKind(), values.begin()),
iterator<T>(getKind(), values.end())};
return {iterator<T>(*this, values.begin()),
iterator<T>(*this, values.end())};
}
if (SparseElementsAttr sparseAttr = dyn_cast<SparseElementsAttr>()) {
auto values = sparseAttr.getValues<T>();
return {iterator<T>(getKind(), values.begin()),
iterator<T>(getKind(), values.end())};
return {iterator<T>(*this, values.begin()),
iterator<T>(*this, values.end())};
}
llvm_unreachable("unexpected attribute kind");
}

View File

@@ -42,10 +42,7 @@ public:
using Attribute::Attribute;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Attribute attr) {
return attr.getKind() >= StandardAttributes::FIRST_LOCATION_ATTR &&
attr.getKind() <= StandardAttributes::LAST_LOCATION_ATTR;
}
static bool classof(Attribute attr);
};
/// This class defines the main interface for locations in MLIR and acts as a

View File

@@ -1472,18 +1472,15 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
// ODS already generates checks to make sure the result type is valid. We just
// need to additionally check that the value's attribute type is consistent
// with the result type.
switch (value.getKind()) {
case StandardAttributes::Integer:
case StandardAttributes::Float: {
if (value.isa<IntegerAttr, FloatAttr>()) {
if (valueType != opType)
return constOp.emitOpError("result type (")
<< opType << ") does not match value type (" << valueType << ")";
return success();
} break;
case StandardAttributes::DenseIntOrFPElements:
case StandardAttributes::SparseElements: {
}
if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
if (valueType == opType)
break;
return success();
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
auto shapedType = valueType.dyn_cast<ShapedType>();
if (!arrayType) {
@@ -1497,9 +1494,8 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
numElements *= t.getNumElements();
opElemType = t.getElementType();
}
if (!opElemType.isIntOrFloat()) {
if (!opElemType.isIntOrFloat())
return constOp.emitOpError("only support nested array result type");
}
auto valueElemType = shapedType.getElementType();
if (valueElemType != opElemType) {
@@ -1513,26 +1509,24 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
<< numElements << ") does not match value number of elements ("
<< shapedType.getNumElements() << ")";
}
} break;
case StandardAttributes::Array: {
return success();
}
if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
if (!arrayType)
return constOp.emitOpError(
"must have spv.array result type for array value");
auto elemType = arrayType.getElementType();
for (auto element : value.cast<ArrayAttr>().getValue()) {
Type elemType = arrayType.getElementType();
for (Attribute element : attayAttr.getValue()) {
if (element.getType() != elemType)
return constOp.emitOpError("has array element whose type (")
<< element.getType()
<< ") does not match the result element type (" << elemType
<< ')';
}
} break;
default:
return constOp.emitOpError("cannot have value of type ") << valueType;
return success();
}
return success();
return constOp.emitOpError("cannot have value of type ") << valueType;
}
bool spirv::ConstantOp::isBuildableWith(Type type) {
@@ -2619,19 +2613,14 @@ static LogicalResult verify(spirv::SpecConstantOp constOp) {
return constOp.emitOpError("SpecId cannot be negative");
auto value = constOp.default_value();
switch (value.getKind()) {
case StandardAttributes::Integer:
case StandardAttributes::Float: {
if (value.isa<IntegerAttr, FloatAttr>()) {
// Make sure bitwidth is allowed.
if (!value.getType().isa<spirv::SPIRVType>())
return constOp.emitOpError("default value bitwidth disallowed");
return success();
}
default:
return constOp.emitOpError(
"default value can only be a bool, integer, or float scalar");
}
return constOp.emitOpError(
"default value can only be a bool, integer, or float scalar");
}
//===----------------------------------------------------------------------===//

View File

@@ -33,6 +33,7 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SaveAndRestore.h"
@@ -1019,76 +1020,67 @@ void ModulePrinter::printTrailingLocation(Location loc) {
}
void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
switch (loc.getKind()) {
case StandardAttributes::OpaqueLocation:
printLocationInternal(loc.cast<OpaqueLoc>().getFallbackLocation(), pretty);
break;
case StandardAttributes::UnknownLocation:
if (pretty)
os << "[unknown]";
else
os << "unknown";
break;
case StandardAttributes::FileLineColLocation: {
auto fileLoc = loc.cast<FileLineColLoc>();
auto mayQuote = pretty ? "" : "\"";
os << mayQuote << fileLoc.getFilename() << mayQuote << ':'
<< fileLoc.getLine() << ':' << fileLoc.getColumn();
break;
}
case StandardAttributes::NameLocation: {
auto nameLoc = loc.cast<NameLoc>();
os << '\"' << nameLoc.getName() << '\"';
TypeSwitch<LocationAttr>(loc)
.Case<OpaqueLoc>([&](OpaqueLoc loc) {
printLocationInternal(loc.getFallbackLocation(), pretty);
})
.Case<UnknownLoc>([&](UnknownLoc loc) {
if (pretty)
os << "[unknown]";
else
os << "unknown";
})
.Case<FileLineColLoc>([&](FileLineColLoc loc) {
StringRef mayQuote = pretty ? "" : "\"";
os << mayQuote << loc.getFilename() << mayQuote << ':' << loc.getLine()
<< ':' << loc.getColumn();
})
.Case<NameLoc>([&](NameLoc loc) {
os << '\"' << loc.getName() << '\"';
// Print the child if it isn't unknown.
auto childLoc = nameLoc.getChildLoc();
if (!childLoc.isa<UnknownLoc>()) {
os << '(';
printLocationInternal(childLoc, pretty);
os << ')';
}
break;
}
case StandardAttributes::CallSiteLocation: {
auto callLocation = loc.cast<CallSiteLoc>();
auto caller = callLocation.getCaller();
auto callee = callLocation.getCallee();
if (!pretty)
os << "callsite(";
printLocationInternal(callee, pretty);
if (pretty) {
if (callee.isa<NameLoc>()) {
if (caller.isa<FileLineColLoc>()) {
os << " at ";
} else {
os << newLine << " at ";
// Print the child if it isn't unknown.
auto childLoc = loc.getChildLoc();
if (!childLoc.isa<UnknownLoc>()) {
os << '(';
printLocationInternal(childLoc, pretty);
os << ')';
}
} else {
os << newLine << " at ";
}
} else {
os << " at ";
}
printLocationInternal(caller, pretty);
if (!pretty)
os << ")";
break;
}
case StandardAttributes::FusedLocation: {
auto fusedLoc = loc.cast<FusedLoc>();
if (!pretty)
os << "fused";
if (auto metadata = fusedLoc.getMetadata())
os << '<' << metadata << '>';
os << '[';
interleave(
fusedLoc.getLocations(),
[&](Location loc) { printLocationInternal(loc, pretty); },
[&]() { os << ", "; });
os << ']';
break;
}
}
})
.Case<CallSiteLoc>([&](CallSiteLoc loc) {
Location caller = loc.getCaller();
Location callee = loc.getCallee();
if (!pretty)
os << "callsite(";
printLocationInternal(callee, pretty);
if (pretty) {
if (callee.isa<NameLoc>()) {
if (caller.isa<FileLineColLoc>()) {
os << " at ";
} else {
os << newLine << " at ";
}
} else {
os << newLine << " at ";
}
} else {
os << " at ";
}
printLocationInternal(caller, pretty);
if (!pretty)
os << ")";
})
.Case<FusedLoc>([&](FusedLoc loc) {
if (!pretty)
os << "fused";
if (Attribute metadata = loc.getMetadata())
os << '<' << metadata << '>';
os << '[';
interleave(
loc.getLocations(),
[&](Location loc) { printLocationInternal(loc, pretty); },
[&]() { os << ", "; });
os << ']';
});
}
/// Print a floating point value in a way that the parser will be able to
@@ -1305,27 +1297,19 @@ void ModulePrinter::printAttribute(Attribute attr,
}
auto attrType = attr.getType();
switch (attr.getKind()) {
default:
return printDialectAttribute(attr);
case StandardAttributes::Opaque: {
auto opaqueAttr = attr.cast<OpaqueAttr>();
if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
opaqueAttr.getAttrData());
break;
}
case StandardAttributes::Unit:
} else if (attr.isa<UnitAttr>()) {
os << "unit";
break;
case StandardAttributes::Dictionary:
return;
} else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
os << '{';
interleaveComma(attr.cast<DictionaryAttr>().getValue(),
interleaveComma(dictAttr.getValue(),
[&](NamedAttribute attr) { printNamedAttribute(attr); });
os << '}';
break;
case StandardAttributes::Integer: {
auto intAttr = attr.cast<IntegerAttr>();
} else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
if (attrType.isSignlessInteger(1)) {
os << (intAttr.getValue().getBoolValue() ? "true" : "false");
@@ -1343,114 +1327,98 @@ void ModulePrinter::printAttribute(Attribute attr,
// IntegerAttr elides the type if I64.
if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
return;
break;
}
case StandardAttributes::Float: {
auto floatAttr = attr.cast<FloatAttr>();
} else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
printFloatValue(floatAttr.getValue(), os);
// FloatAttr elides the type if F64.
if (typeElision == AttrTypeElision::May && attrType.isF64())
return;
break;
}
case StandardAttributes::String:
} else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
os << '"';
printEscapedString(attr.cast<StringAttr>().getValue(), os);
printEscapedString(strAttr.getValue(), os);
os << '"';
break;
case StandardAttributes::Array:
} else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
os << '[';
interleaveComma(attr.cast<ArrayAttr>().getValue(), [&](Attribute attr) {
interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
printAttribute(attr, AttrTypeElision::May);
});
os << ']';
break;
case StandardAttributes::AffineMap:
} else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
os << "affine_map<";
attr.cast<AffineMapAttr>().getValue().print(os);
affineMapAttr.getValue().print(os);
os << '>';
// AffineMap always elides the type.
return;
case StandardAttributes::IntegerSet:
} else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
os << "affine_set<";
attr.cast<IntegerSetAttr>().getValue().print(os);
integerSetAttr.getValue().print(os);
os << '>';
// IntegerSet always elides the type.
return;
case StandardAttributes::Type:
printType(attr.cast<TypeAttr>().getValue());
break;
case StandardAttributes::SymbolRef: {
auto refAttr = attr.dyn_cast<SymbolRefAttr>();
} else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
printType(typeAttr.getValue());
} else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
printSymbolReference(refAttr.getRootReference(), os);
for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
os << "::";
printSymbolReference(nestedRef.getValue(), os);
}
break;
}
case StandardAttributes::OpaqueElements: {
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
printElidedElementsAttr(os);
break;
}
os << "opaque<\"" << eltsAttr.getDialect()->getNamespace() << "\", ";
os << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << "\">";
break;
}
case StandardAttributes::DenseIntOrFPElements: {
auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>();
if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
printElidedElementsAttr(os);
break;
}
os << "dense<";
printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true);
os << '>';
break;
}
case StandardAttributes::DenseStringElements: {
auto eltsAttr = attr.cast<DenseStringElementsAttr>();
if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
printElidedElementsAttr(os);
break;
}
os << "dense<";
printDenseStringElementsAttr(eltsAttr);
os << '>';
break;
}
case StandardAttributes::SparseElements: {
auto elementsAttr = attr.cast<SparseElementsAttr>();
if (printerFlags.shouldElideElementsAttr(elementsAttr.getIndices()) ||
printerFlags.shouldElideElementsAttr(elementsAttr.getValues())) {
printElidedElementsAttr(os);
break;
}
os << "sparse<";
DenseIntElementsAttr indices = elementsAttr.getIndices();
if (indices.getNumElements() != 0) {
printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
os << ", ";
printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
}
os << '>';
break;
}
// Location attributes.
case StandardAttributes::CallSiteLocation:
case StandardAttributes::FileLineColLocation:
case StandardAttributes::FusedLocation:
case StandardAttributes::NameLocation:
case StandardAttributes::OpaqueLocation:
case StandardAttributes::UnknownLocation:
printLocation(attr.cast<LocationAttr>());
break;
} else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
printElidedElementsAttr(os);
} else {
os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", ";
os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">";
}
} else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
printElidedElementsAttr(os);
} else {
os << "dense<";
printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
os << '>';
}
} else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
printElidedElementsAttr(os);
} else {
os << "dense<";
printDenseStringElementsAttr(strEltAttr);
os << '>';
}
} else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
printElidedElementsAttr(os);
} else {
os << "sparse<";
DenseIntElementsAttr indices = sparseEltAttr.getIndices();
if (indices.getNumElements() != 0) {
printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
os << ", ";
printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
}
os << '>';
}
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
printLocation(locAttr);
} else {
return printDialectAttribute(attr);
}
// Don't print the type if we must elide it, or if it is a None type.

View File

@@ -460,16 +460,11 @@ int64_t ElementsAttr::getNumElements() const {
/// Return the value at the given index. If index does not refer to a valid
/// element, then a null attribute is returned.
Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
switch (getKind()) {
case StandardAttributes::DenseIntOrFPElements:
return cast<DenseElementsAttr>().getValue(index);
case StandardAttributes::OpaqueElements:
return cast<OpaqueElementsAttr>().getValue(index);
case StandardAttributes::SparseElements:
return cast<SparseElementsAttr>().getValue(index);
default:
llvm_unreachable("unknown ElementsAttr kind");
}
if (auto denseAttr = dyn_cast<DenseElementsAttr>())
return denseAttr.getValue(index);
if (auto opaqueAttr = dyn_cast<OpaqueElementsAttr>())
return opaqueAttr.getValue(index);
return cast<SparseElementsAttr>().getValue(index);
}
/// Return if the given 'index' refers to a valid element in this attribute.
@@ -491,23 +486,23 @@ bool ElementsAttr::isValidIndex(ArrayRef<uint64_t> index) const {
ElementsAttr
ElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APInt &)> mapping) const {
switch (getKind()) {
case StandardAttributes::DenseIntOrFPElements:
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
default:
llvm_unreachable("unsupported ElementsAttr subtype");
}
if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
return intOrFpAttr.mapValues(newElementType, mapping);
llvm_unreachable("unsupported ElementsAttr subtype");
}
ElementsAttr
ElementsAttr::mapValues(Type newElementType,
function_ref<APInt(const APFloat &)> mapping) const {
switch (getKind()) {
case StandardAttributes::DenseIntOrFPElements:
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
default:
llvm_unreachable("unsupported ElementsAttr subtype");
}
if (auto intOrFpAttr = dyn_cast<DenseElementsAttr>())
return intOrFpAttr.mapValues(newElementType, mapping);
llvm_unreachable("unsupported ElementsAttr subtype");
}
/// Method for support type inquiry through isa, cast and dyn_cast.
bool ElementsAttr::classof(Attribute attr) {
return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr,
OpaqueElementsAttr, SparseElementsAttr>();
}
/// Returns the 1 dimensional flattened row-major index from the given
@@ -718,6 +713,11 @@ DenseElementsAttr::ComplexFloatElementIterator::ComplexFloatElementIterator(
// DenseElementsAttr
//===----------------------------------------------------------------------===//
/// Method for support type inquiry through isa, cast and dyn_cast.
bool DenseElementsAttr::classof(Attribute attr) {
return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
}
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
ArrayRef<Attribute> values) {
assert(hasSameElementsOrSplat(type, values));

View File

@@ -366,43 +366,38 @@ struct SourceMgrDiagnosticHandlerImpl {
/// Return a processable FileLineColLoc from the given location.
static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
switch (loc->getKind()) {
case StandardAttributes::NameLocation:
if (auto nameLoc = loc.dyn_cast<NameLoc>())
return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
case StandardAttributes::FileLineColLocation:
return loc.cast<FileLineColLoc>();
case StandardAttributes::CallSiteLocation:
// Process the callee of a callsite location.
if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
return fileLoc;
if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
case StandardAttributes::FusedLocation:
if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
if (auto callLoc = getFileLineColLoc(subLoc)) {
return callLoc;
}
}
return llvm::None;
default:
return llvm::None;
}
return llvm::None;
}
/// Return a processable CallSiteLoc from the given location.
static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
switch (loc->getKind()) {
case StandardAttributes::NameLocation:
if (auto nameLoc = loc.dyn_cast<NameLoc>())
return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
case StandardAttributes::CallSiteLocation:
return loc.cast<CallSiteLoc>();
case StandardAttributes::FusedLocation:
if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
return callLoc;
if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
if (auto callLoc = getCallSiteLoc(subLoc)) {
return callLoc;
}
}
return llvm::None;
default:
return llvm::None;
}
return llvm::None;
}
/// Given a diagnostic kind, returns the LLVM DiagKind.

View File

@@ -13,6 +13,16 @@
using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
// LocationAttr
//===----------------------------------------------------------------------===//
/// Methods for support type inquiry through isa, cast, and dyn_cast.
bool LocationAttr::classof(Attribute attr) {
return attr.isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
UnknownLoc>();
}
//===----------------------------------------------------------------------===//
// CallSiteLoc
//===----------------------------------------------------------------------===//

View File

@@ -115,26 +115,19 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
return existingIt->second;
const llvm::DILocation *llvmLoc = nullptr;
switch (loc->getKind()) {
case StandardAttributes::CallSiteLocation: {
auto callLoc = loc.dyn_cast<CallSiteLoc>();
if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
// For callsites, the caller is fed as the inlinedAt for the callee.
const auto *callerLoc = translateLoc(callLoc.getCaller(), scope, inlinedAt);
llvmLoc = translateLoc(callLoc.getCallee(), scope, callerLoc);
break;
}
case StandardAttributes::FileLineColLocation: {
auto fileLoc = loc.dyn_cast<FileLineColLoc>();
} else if (auto fileLoc = loc.dyn_cast<FileLineColLoc>()) {
auto *file = translateFile(fileLoc.getFilename());
auto *fileScope = builder.createLexicalBlockFile(scope, file);
llvmLoc = llvm::DILocation::get(llvmCtx, fileLoc.getLine(),
fileLoc.getColumn(), fileScope,
const_cast<llvm::DILocation *>(inlinedAt));
break;
}
case StandardAttributes::FusedLocation: {
auto fusedLoc = loc.dyn_cast<FusedLoc>();
} else if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
ArrayRef<Location> locations = fusedLoc.getLocations();
// For fused locations, merge each of the nodes.
@@ -143,18 +136,17 @@ DebugTranslation::translateLoc(Location loc, llvm::DILocalScope *scope,
llvmLoc = llvm::DILocation::getMergedLocation(
llvmLoc, translateLoc(locIt, scope, inlinedAt));
}
break;
}
case StandardAttributes::NameLocation:
} else if (auto nameLoc = loc.dyn_cast<NameLoc>()) {
llvmLoc = translateLoc(loc.cast<NameLoc>().getChildLoc(), scope, inlinedAt);
break;
case StandardAttributes::OpaqueLocation:
} else if (auto opaqueLoc = loc.dyn_cast<OpaqueLoc>()) {
llvmLoc = translateLoc(loc.cast<OpaqueLoc>().getFallbackLocation(), scope,
inlinedAt);
break;
default:
} else {
llvm_unreachable("unknown location kind");
}
locationToLoc.try_emplace(std::make_pair(loc, scope), llvmLoc);
return llvmLoc;
}

View File

@@ -158,7 +158,7 @@ TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
auto expectedTensorType = realValue.getType().cast<TensorType>();
EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
EXPECT_EQ(tensorType.getElementType(), convertedType);
EXPECT_EQ(returnedValue.getKind(), StandardAttributes::SparseElements);
EXPECT_TRUE(returnedValue.isa<SparseElementsAttr>());
// Check Elements attribute element value is expected.
auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});