2018-11-20 09:38:15 -08:00
|
|
|
//===- OpStats.cpp - Prints stats of operations in module -----------------===//
|
|
|
|
|
//
|
2020-01-26 03:58:30 +00:00
|
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
2019-12-23 09:35:36 -08:00
|
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
|
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
2018-11-20 09:38:15 -08:00
|
|
|
//
|
2019-12-23 09:35:36 -08:00
|
|
|
//===----------------------------------------------------------------------===//
|
2018-11-20 09:38:15 -08:00
|
|
|
|
2020-04-07 13:58:12 -07:00
|
|
|
#include "PassDetail.h"
|
2018-12-29 15:33:43 -08:00
|
|
|
#include "mlir/IR/Module.h"
|
2019-03-26 14:45:38 -07:00
|
|
|
#include "mlir/IR/Operation.h"
|
2018-11-20 09:38:15 -08:00
|
|
|
#include "mlir/IR/OperationSupport.h"
|
2020-02-12 09:03:40 +00:00
|
|
|
#include "mlir/Transforms/Passes.h"
|
2018-11-20 09:38:15 -08:00
|
|
|
#include "llvm/ADT/DenseMap.h"
|
Rewrite OpStats to use llvm formatting utilities.
Example Output:
Operations encountered:
-----------------------
addf , 11
constant , 4
return , 19
some_op , 1
tf.AvgPool , 3
tf.DepthwiseConv2dNative , 3
tf.FusedBatchNorm , 2
tfl.add , 7
tfl.average_pool_2d , 1
tfl.leaky_relu , 1
PiperOrigin-RevId: 229937190
2019-01-18 09:00:34 -08:00
|
|
|
#include "llvm/Support/Format.h"
|
2018-11-20 09:38:15 -08:00
|
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
|
|
|
|
|
|
using namespace mlir;
|
|
|
|
|
|
|
|
|
|
namespace {
|
2020-04-07 13:58:12 -07:00
|
|
|
struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> {
|
2019-12-18 09:28:48 -08:00
|
|
|
explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {}
|
2018-11-20 09:38:15 -08:00
|
|
|
|
2018-12-30 10:43:03 -08:00
|
|
|
// Prints the resultant operation statistics post iterating over the module.
|
2020-04-07 13:55:34 -07:00
|
|
|
void runOnOperation() override;
|
2018-11-20 09:38:15 -08:00
|
|
|
|
|
|
|
|
// Print summary of op stats.
|
|
|
|
|
void printSummary();
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
llvm::StringMap<int64_t> opCount;
|
2019-12-18 09:28:48 -08:00
|
|
|
raw_ostream &os;
|
2018-11-20 09:38:15 -08:00
|
|
|
};
|
|
|
|
|
} // namespace
|
|
|
|
|
|
2020-04-07 13:55:34 -07:00
|
|
|
void PrintOpStatsPass::runOnOperation() {
|
2019-02-04 16:24:44 -08:00
|
|
|
opCount.clear();
|
|
|
|
|
|
|
|
|
|
// Compute the operation statistics for each function in the module.
|
2020-04-07 13:55:34 -07:00
|
|
|
for (auto &op : getOperation())
|
2019-07-08 18:27:45 -07:00
|
|
|
op.walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; });
|
2018-12-29 15:33:43 -08:00
|
|
|
printSummary();
|
2018-11-20 09:38:15 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PrintOpStatsPass::printSummary() {
|
|
|
|
|
os << "Operations encountered:\n";
|
|
|
|
|
os << "-----------------------\n";
|
2019-05-29 11:59:26 -07:00
|
|
|
SmallVector<StringRef, 64> sorted(opCount.keys());
|
2018-11-20 09:38:15 -08:00
|
|
|
llvm::sort(sorted);
|
|
|
|
|
|
Rewrite OpStats to use llvm formatting utilities.
Example Output:
Operations encountered:
-----------------------
addf , 11
constant , 4
return , 19
some_op , 1
tf.AvgPool , 3
tf.DepthwiseConv2dNative , 3
tf.FusedBatchNorm , 2
tfl.add , 7
tfl.average_pool_2d , 1
tfl.leaky_relu , 1
PiperOrigin-RevId: 229937190
2019-01-18 09:00:34 -08:00
|
|
|
// Split an operation name from its dialect prefix.
|
|
|
|
|
auto splitOperationName = [](StringRef opName) {
|
|
|
|
|
auto splitName = opName.split('.');
|
|
|
|
|
return splitName.second.empty() ? std::make_pair("", splitName.first)
|
|
|
|
|
: splitName;
|
2018-11-20 09:38:15 -08:00
|
|
|
};
|
|
|
|
|
|
Rewrite OpStats to use llvm formatting utilities.
Example Output:
Operations encountered:
-----------------------
addf , 11
constant , 4
return , 19
some_op , 1
tf.AvgPool , 3
tf.DepthwiseConv2dNative , 3
tf.FusedBatchNorm , 2
tfl.add , 7
tfl.average_pool_2d , 1
tfl.leaky_relu , 1
PiperOrigin-RevId: 229937190
2019-01-18 09:00:34 -08:00
|
|
|
// Compute the largest dialect and operation name.
|
|
|
|
|
StringRef dialectName, opName;
|
|
|
|
|
size_t maxLenOpName = 0, maxLenDialect = 0;
|
2018-11-20 09:38:15 -08:00
|
|
|
for (const auto &key : sorted) {
|
Rewrite OpStats to use llvm formatting utilities.
Example Output:
Operations encountered:
-----------------------
addf , 11
constant , 4
return , 19
some_op , 1
tf.AvgPool , 3
tf.DepthwiseConv2dNative , 3
tf.FusedBatchNorm , 2
tfl.add , 7
tfl.average_pool_2d , 1
tfl.leaky_relu , 1
PiperOrigin-RevId: 229937190
2019-01-18 09:00:34 -08:00
|
|
|
std::tie(dialectName, opName) = splitOperationName(key);
|
|
|
|
|
maxLenDialect = std::max(maxLenDialect, dialectName.size());
|
|
|
|
|
maxLenOpName = std::max(maxLenOpName, opName.size());
|
2018-11-20 09:38:15 -08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto &key : sorted) {
|
Rewrite OpStats to use llvm formatting utilities.
Example Output:
Operations encountered:
-----------------------
addf , 11
constant , 4
return , 19
some_op , 1
tf.AvgPool , 3
tf.DepthwiseConv2dNative , 3
tf.FusedBatchNorm , 2
tfl.add , 7
tfl.average_pool_2d , 1
tfl.leaky_relu , 1
PiperOrigin-RevId: 229937190
2019-01-18 09:00:34 -08:00
|
|
|
std::tie(dialectName, opName) = splitOperationName(key);
|
|
|
|
|
|
|
|
|
|
// Left-align the names (aligning on the dialect) and right-align the count
|
|
|
|
|
// below. The alignment is for readability and does not affect CSV/FileCheck
|
|
|
|
|
// parsing.
|
|
|
|
|
if (dialectName.empty())
|
|
|
|
|
os.indent(maxLenDialect + 3);
|
|
|
|
|
else
|
|
|
|
|
os << llvm::right_justify(dialectName, maxLenDialect + 2) << '.';
|
|
|
|
|
|
|
|
|
|
// Left justify the operation name.
|
|
|
|
|
os << llvm::left_justify(opName, maxLenOpName) << " , " << opCount[key]
|
|
|
|
|
<< '\n';
|
2018-11-20 09:38:15 -08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2020-04-07 13:56:16 -07:00
|
|
|
/// Creates a pass that prints operation statistics for a module to
/// llvm::errs().
std::unique_ptr<OperationPass<ModuleOp>> mlir::createPrintOpStatsPass() {
  return std::make_unique<PrintOpStatsPass>(llvm::errs());
}
|