mirror of
https://github.com/intel/llvm.git
synced 2026-01-27 06:06:34 +08:00
[mlir-tblgen] Avoid ODS verifier duplication
Different constraints may share the same predicate, in this case, we will generate duplicate ODS verification function. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D104369
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#define MLIR_TABLEGEN_PREDICATE_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@@ -59,6 +60,8 @@ public:
|
||||
ArrayRef<llvm::SMLoc> getLoc() const;
|
||||
|
||||
protected:
|
||||
friend llvm::DenseMapInfo<Pred>;
|
||||
|
||||
// The TableGen definition of this predicate.
|
||||
const llvm::Record *def;
|
||||
};
|
||||
@@ -116,4 +119,18 @@ public:
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
namespace llvm {
|
||||
template <>
|
||||
struct DenseMapInfo<mlir::tblgen::Pred> {
|
||||
static mlir::tblgen::Pred getEmptyKey() { return mlir::tblgen::Pred(); }
|
||||
static mlir::tblgen::Pred getTombstoneKey() { return mlir::tblgen::Pred(); }
|
||||
static unsigned getHashValue(mlir::tblgen::Pred pred) {
|
||||
return llvm::hash_value(pred.def);
|
||||
}
|
||||
static bool isEqual(mlir::tblgen::Pred lhs, mlir::tblgen::Pred rhs) {
|
||||
return lhs == rhs;
|
||||
}
|
||||
};
|
||||
} // end namespace llvm
|
||||
|
||||
#endif // MLIR_TABLEGEN_PREDICATE_H_
|
||||
|
||||
@@ -13,19 +13,24 @@ def I32OrF32 : Type<CPred<"$_self.isInteger(32) || $_self.isF32()">,
|
||||
|
||||
def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
|
||||
let arguments = (ins I32OrF32:$x);
|
||||
let results = (outs Variadic<I32OrF32>:$y);
|
||||
}
|
||||
|
||||
// CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
|
||||
// CHECK: if (!((type.isInteger(32) || type.isF32()))) {
|
||||
// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
|
||||
// CHECK-NEXT: if (!((type.isInteger(32) || type.isF32()))) {
|
||||
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
|
||||
|
||||
// Check there is no verifier with same predicate generated.
|
||||
// CHECK-NOT: if (!((type.isInteger(32) || type.isF32()))) {
|
||||
// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
|
||||
|
||||
// CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
|
||||
// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
|
||||
// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
|
||||
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
|
||||
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
|
||||
|
||||
// CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
|
||||
// CHECK: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
|
||||
// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
|
||||
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
|
||||
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
|
||||
|
||||
// CHECK-LABEL: OpA::verify
|
||||
// CHECK: auto valueGroup0 = getODSOperands(0);
|
||||
|
||||
@@ -216,19 +216,50 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
|
||||
typeConstraints.insert(result.constraint.getAsOpaquePointer());
|
||||
}
|
||||
|
||||
// Record the mapping from predicate to constraint. If two constraints has the
|
||||
// same predicate and constraint summary, they can share the same verification
|
||||
// function.
|
||||
llvm::DenseMap<Pred, const void *> predToConstraint;
|
||||
FmtContext fctx;
|
||||
for (auto it : llvm::enumerate(typeConstraints)) {
|
||||
std::string name;
|
||||
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
|
||||
Pred pred = constraint.getPredicate();
|
||||
auto iter = predToConstraint.find(pred);
|
||||
if (iter != predToConstraint.end()) {
|
||||
do {
|
||||
Constraint built = Constraint::getFromOpaquePointer(iter->second);
|
||||
// We may have the different constraints but have the same predicate,
|
||||
// for example, ConstraintA and Variadic<ConstraintA>, note that
|
||||
// Variadic<> doesn't introduce new predicate. In this case, we can
|
||||
// share the same predicate function if they also have consistent
|
||||
// summary, otherwise we may report the wrong message while verification
|
||||
// fails.
|
||||
if (constraint.getSummary() == built.getSummary()) {
|
||||
name = getTypeConstraintFn(built).str();
|
||||
break;
|
||||
}
|
||||
++iter;
|
||||
} while (iter != predToConstraint.end() && iter->first == pred);
|
||||
}
|
||||
|
||||
if (!name.empty()) {
|
||||
localTypeConstraints.try_emplace(it.value(), name);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Generate an obscure and unique name for this type constraint.
|
||||
std::string name = (Twine("__mlir_ods_local_type_constraint_") +
|
||||
uniqueOutputLabel + Twine(it.index()))
|
||||
.str();
|
||||
name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
|
||||
Twine(it.index()))
|
||||
.str();
|
||||
predToConstraint.insert(
|
||||
std::make_pair(constraint.getPredicate(), it.value()));
|
||||
localTypeConstraints.try_emplace(it.value(), name);
|
||||
|
||||
// Only generate the methods if we are generating definitions.
|
||||
if (emitDecl)
|
||||
continue;
|
||||
|
||||
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
|
||||
os << "static ::mlir::LogicalResult " << name
|
||||
<< "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
|
||||
"valueKind, unsigned valueGroupStartIndex) {\n";
|
||||
|
||||
Reference in New Issue
Block a user