Files
llvm/mlir/test/lib/Dialect/Shape/TestShapeFunctions.cpp
River Riddle fe7c0d90b2 [mlir][IR] Remove the concept of OperationProperties
These properties were useful for a few things before traits had a better integration story, but don't really carry their weight well these days. Most of these properties are already checked via traits in most of the code. It is better to align the system around traits, and improve the performance/cost of traits in general.

Differential Revision: https://reviews.llvm.org/D96088
2021-02-09 12:00:15 -08:00

90 lines
2.9 KiB
C++

//===- TestShapeFunctions.cpp - Passes to test shape functions -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <queue>
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// Module-level test pass that emits a remark on operations reporting how
/// (if at all) their shapes can be refined.
class ReportShapeFnPass
    : public PassWrapper<ReportShapeFnPass, OperationPass<ModuleOp>> {
public:
  void runOnOperation() override;
};
} // namespace
void ReportShapeFnPass::runOnOperation() {
auto module = getOperation();
// Report the shape function available to refine the op.
auto shapeFnId = Identifier::get("shape.function", &getContext());
auto remarkShapeFn = [&](shape::FunctionLibraryOp shapeFnLib, Operation *op) {
if (op->hasTrait<OpTrait::IsTerminator>())
return true;
if (auto typeInterface = dyn_cast<InferTypeOpInterface>(op)) {
op->emitRemark() << "implements InferType op interface";
return true;
}
if (auto fn = shapeFnLib.getShapeFunction(op)) {
op->emitRemark() << "associated shape function: " << fn.getName();
return true;
}
if (auto symbol = op->getAttrOfType<SymbolRefAttr>(shapeFnId)) {
auto fn = cast<FuncOp>(SymbolTable::lookupSymbolIn(module, symbol));
op->emitRemark() << "associated shape function: " << fn.getName();
return true;
}
return false;
};
// Lookup shape function library.
SmallVector<shape::FunctionLibraryOp, 4> libraries;
auto attr = module->getAttr("shape.lib");
if (attr) {
auto lookup = [&](Attribute attr) {
return cast<shape::FunctionLibraryOp>(
SymbolTable::lookupSymbolIn(module, attr.cast<SymbolRefAttr>()));
};
if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
libraries.reserve(arrayAttr.size());
for (auto attr : arrayAttr)
libraries.push_back(lookup(attr));
} else {
libraries.reserve(1);
libraries.push_back(lookup(attr));
}
}
module.getBodyRegion().walk([&](FuncOp func) {
// Skip ops in the shape function library.
if (isa<shape::FunctionLibraryOp>(func->getParentOp()))
return;
func.walk([&](Operation *op) {
bool found = llvm::any_of(libraries, [&](shape::FunctionLibraryOp lib) {
return remarkShapeFn(lib, op);
});
if (!found)
op->emitRemark() << "no associated way to refine shape";
});
});
}
namespace mlir {
void registerShapeFunctionTestPasses() {
PassRegistration<ReportShapeFnPass>(
"test-shape-function-report",
"Test pass to report associated shape functions");
}
} // namespace mlir