Improve MLIR "view-op-graph" to color operations according to their name

Differential Revision: https://reviews.llvm.org/D153290
This commit is contained in:
Mehdi Amini
2023-06-19 14:35:18 +02:00
parent 4eeb3647d1
commit 6eca120dd8

View File

@@ -87,6 +87,7 @@ public:
PrintOpPass(const PrintOpPass &o) : PrintOpPass(o.os.getOStream()) {}
void runOnOperation() override {
initColorMapping(*getOperation());
emitGraph([&]() {
processOperation(getOperation());
emitAllEdgeStmts();
@@ -97,10 +98,31 @@ public:
void emitRegionCFG(Region &region) {
printControlFlowEdges = true;
printDataFlowEdges = false;
initColorMapping(region);
emitGraph([&]() { processRegion(region); });
}
private:
/// Generate a color mapping that will color every operation with the same
/// name the same way. It'll interpolate the hue in the HSV color-space,
/// attempting to keep the contrast suitable for black text.
template <typename T>
void initColorMapping(T &irEntity) {
backgroundColors.clear();
SmallVector<Operation *> ops;
irEntity.walk([&](Operation *op) {
auto &entry = backgroundColors[op->getName()];
if (entry.first == 0)
ops.push_back(op);
++entry.first;
});
for (auto indexedOps : llvm::enumerate(ops)) {
double hue = ((double)indexedOps.index()) / ops.size();
backgroundColors[indexedOps.value()->getName()].second =
std::to_string(hue) + " 1.0 1.0";
}
}
/// Emit all edges. This function should be called after all nodes have been
/// emitted.
void emitAllEdgeStmts() {
@@ -206,11 +228,16 @@ private:
}
/// Emit a node statement.
Node emitNodeStmt(std::string label, StringRef shape = kShapeNode) {
Node emitNodeStmt(std::string label, StringRef shape = kShapeNode,
StringRef background = "") {
int nodeId = ++counter;
AttributeMap attrs;
attrs["label"] = quoteString(escapeString(std::move(label)));
attrs["shape"] = shape.str();
if (!background.empty()) {
attrs["style"] = "filled";
attrs["fillcolor"] = ("\"" + background + "\"").str();
}
os << llvm::format("v%i ", nodeId);
emitAttrList(os, attrs);
os << ";\n";
@@ -278,7 +305,8 @@ private:
},
getLabel(op));
} else {
node = emitNodeStmt(getLabel(op));
node = emitNodeStmt(getLabel(op), kShapeNode,
backgroundColors[op->getName()].second);
}
// Insert data flow edges originating from each operand.
@@ -318,6 +346,8 @@ private:
DenseMap<Value, Node> valueToNode;
/// Counter for generating unique node/subgraph identifiers.
int counter = 0;
DenseMap<OperationName, std::pair<int, std::string>> backgroundColors;
};
} // namespace