mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
Verify CmpIOp's result type to be bool-like
This CL added two new traits, SameOperandsAndResultShape and ResultsAreBoolLike, and changed CmpIOp to embody these two traits. As a consequence, CmpIOp's result type now is verified to be bool-like. PiperOrigin-RevId: 223208438
This commit is contained in:
@@ -85,6 +85,7 @@ public:
|
||||
OtherType getTFComplex128Type();
|
||||
OtherType getTFF32REFType();
|
||||
|
||||
IntegerType getI1Type();
|
||||
IntegerType getIntegerType(unsigned width);
|
||||
FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
|
||||
MemRefType getMemRefType(ArrayRef<int> shape, Type elementType,
|
||||
|
||||
@@ -283,7 +283,9 @@ bool verifyZeroResult(const Operation *op);
|
||||
bool verifyOneResult(const Operation *op);
|
||||
bool verifyNResults(const Operation *op, unsigned numOperands);
|
||||
bool verifyAtLeastNResults(const Operation *op, unsigned numOperands);
|
||||
bool verifySameOperandsAndResult(const Operation *op);
|
||||
bool verifySameOperandsAndResultShape(const Operation *op);
|
||||
bool verifySameOperandsAndResultType(const Operation *op);
|
||||
bool verifyResultsAreBoolLike(const Operation *op);
|
||||
bool verifyResultsAreFloatLike(const Operation *op);
|
||||
bool verifyResultsAreIntegerLike(const Operation *op);
|
||||
bool verifyIsTerminator(const Operation *op);
|
||||
@@ -623,14 +625,40 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// This class provides verification for ops that are known to have the same
|
||||
/// operand and result shape: both are scalars, vectors/tensors of the same
|
||||
/// shape.
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultShape
|
||||
: public TraitBase<ConcreteType, SameOperandsAndResultShape> {
|
||||
public:
|
||||
static bool verifyTrait(const Operation *op) {
|
||||
return impl::verifySameOperandsAndResultShape(op);
|
||||
}
|
||||
};
|
||||
|
||||
/// This class provides verification for ops that are known to have the same
|
||||
/// operand and result type.
|
||||
///
|
||||
/// Note: this trait subsumes the SameOperandsAndResultShape trait.
|
||||
/// Additionally, it requires all operands and results should also have
|
||||
/// the same element type.
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultType
|
||||
: public TraitBase<ConcreteType, SameOperandsAndResultType> {
|
||||
public:
|
||||
static bool verifyTrait(const Operation *op) {
|
||||
return impl::verifySameOperandsAndResult(op);
|
||||
return impl::verifySameOperandsAndResultType(op);
|
||||
}
|
||||
};
|
||||
|
||||
/// This class verifies that any results of the specified op have a boolean
|
||||
/// type, a vector thereof, or a tensor thereof.
|
||||
template <typename ConcreteType>
|
||||
class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
|
||||
public:
|
||||
static bool verifyTrait(const Operation *op) {
|
||||
return impl::verifyResultsAreBoolLike(op);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -220,9 +220,11 @@ enum class CmpIPredicate {
|
||||
/// %r1 = cmpi "eq" %0, %1 : i32
|
||||
/// %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64>
|
||||
/// %r3 = "cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
|
||||
class CmpIOp : public Op<CmpIOp, OpTrait::OperandsAreIntegerLike,
|
||||
OpTrait::SameTypeOperands, OpTrait::NOperands<2>::Impl,
|
||||
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
|
||||
class CmpIOp
|
||||
: public Op<CmpIOp, OpTrait::OperandsAreIntegerLike,
|
||||
OpTrait::SameTypeOperands, OpTrait::NOperands<2>::Impl,
|
||||
OpTrait::OneResult, OpTrait::ResultsAreBoolLike,
|
||||
OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
CmpIPredicate getPredicate() const {
|
||||
return (CmpIPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
|
||||
|
||||
@@ -84,6 +84,8 @@ OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
|
||||
|
||||
OtherType Builder::getTFStringType() { return Type::getTFString(context); }
|
||||
|
||||
IntegerType Builder::getI1Type() { return Type::getInteger(1, context); }
|
||||
|
||||
IntegerType Builder::getIntegerType(unsigned width) {
|
||||
return Type::getInteger(width, context);
|
||||
}
|
||||
|
||||
@@ -233,7 +233,7 @@ bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType condInfo;
|
||||
|
||||
// Parse the condition.
|
||||
Type int1Ty = parser->getBuilder().getIntegerType(1);
|
||||
Type int1Ty = parser->getBuilder().getI1Type();
|
||||
if (parser->parseOperand(condInfo) || parser->parseComma() ||
|
||||
parser->resolveOperand(condInfo, int1Ty, result->operands)) {
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
|
||||
@@ -502,7 +502,51 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
|
||||
/// Returns false if the given two types have the same shape. That is,
|
||||
/// they are both scalars, or they are both vectors / ranked tensors with
|
||||
/// the same dimension specifications. The element type does not matter.
|
||||
static bool verifyShapeMatch(Type type1, Type type2) {
|
||||
// Check scalar cases
|
||||
if (type1.isa<IntegerType>() || type1.isa<FloatType>() ||
|
||||
type1.isa<IndexType>())
|
||||
return !(type2.isa<IntegerType>() || type2.isa<FloatType>() ||
|
||||
type2.isa<IndexType>());
|
||||
|
||||
// Check unranked tensor cases
|
||||
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>())
|
||||
return true;
|
||||
|
||||
// Check normal vector/tensor cases
|
||||
if (auto vtType1 = type1.dyn_cast<VectorOrTensorType>()) {
|
||||
auto vtType2 = type2.dyn_cast<VectorOrTensorType>();
|
||||
return !(vtType2 && vtType1.getShape() == vtType2.getShape());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) {
|
||||
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
|
||||
return true;
|
||||
|
||||
auto type = op->getOperand(0)->getType();
|
||||
for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
|
||||
if (verifyShapeMatch(op->getResult(i)->getType(), type))
|
||||
return op->emitOpError(
|
||||
"requires the same shape for all operands and results");
|
||||
}
|
||||
for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i) {
|
||||
if (verifyShapeMatch(op->getOperand(i)->getType(), type))
|
||||
return op->emitOpError(
|
||||
"requires the same shape for all operands and results");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) {
|
||||
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
|
||||
return true;
|
||||
|
||||
auto type = op->getResult(0)->getType();
|
||||
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
|
||||
if (op->getResult(i)->getType() != type)
|
||||
@@ -574,6 +618,18 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) {
|
||||
for (auto *result : op->getResults()) {
|
||||
auto elementType = getTensorOrVectorElementType(result->getType());
|
||||
auto intType = elementType.dyn_cast<IntegerType>();
|
||||
bool isBoolType = intType && intType.getWidth() == 1;
|
||||
if (!isBoolType)
|
||||
return op->emitOpError("requires a bool result type");
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
|
||||
for (auto *result : op->getResults()) {
|
||||
if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
|
||||
|
||||
@@ -426,7 +426,7 @@ bool CallIndirectOp::verify() const {
|
||||
|
||||
// Return the type of the same shape (scalar, vector or tensor) containing i1.
|
||||
static Type getI1SameShape(Builder *build, Type type) {
|
||||
auto i1Type = build->getIntegerType(1);
|
||||
auto i1Type = build->getI1Type();
|
||||
if (type.isa<IntegerType>() || type.isa<FloatType>() || type.isa<IndexType>())
|
||||
return i1Type;
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>())
|
||||
@@ -532,8 +532,7 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
attrs[0].second = builder.getIntegerAttr(static_cast<int64_t>(predicate));
|
||||
result->attributes = attrs;
|
||||
|
||||
// The result of comparison is formed from i1s in the same shape as type.
|
||||
result->addTypes({getI1SameShape(&parser->getBuilder(), type)});
|
||||
result->addTypes({getI1SameShape(&builder, type)});
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -568,12 +567,6 @@ bool CmpIOp::verify() const {
|
||||
predicate >= (int64_t)CmpIPredicate::NumPredicates)
|
||||
return emitOpError("'predicate' attribute value out of range");
|
||||
|
||||
if (getOperand(0)->getType() != getOperand(1)->getType())
|
||||
return emitOpError("requires operands to have the same type");
|
||||
|
||||
if (checkI1SameShape(getOperand(0)->getType(), getResult()->getType()))
|
||||
return emitOpError("result must have the same shape as inputs");
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -215,6 +215,14 @@ bb0(%a : f32, %b : f32):
|
||||
|
||||
// -----
|
||||
|
||||
// Result type must be boolean like.
|
||||
cfgfunc @cfgfunc_with_ops(i32, i32) {
|
||||
bb0(%a : i32, %b : i32):
|
||||
%r = "cmpi"(%a, %b) {predicate: 0} : (i32, i32) -> i32 // expected-error {{op requires a bool result type}}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
cfgfunc @cfgfunc_with_ops(i32, i32) {
|
||||
bb0(%a : i32, %b : i32):
|
||||
// expected-error@+1 {{requires an integer attribute named 'predicate'}}
|
||||
@@ -226,8 +234,8 @@ bb0(%a : i32, %b : i32):
|
||||
cfgfunc @cfgfunc_with_ops() {
|
||||
bb0:
|
||||
%c = constant splat<vector<42 x i32>, 0> : vector<42 x i32>
|
||||
// expected-error@+1 {{op result must have the same shape as inputs}}
|
||||
%r = "cmpi"(%c, %c) {predicate: 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<42 x i32>
|
||||
// expected-error@+1 {{op requires the same shape for all operands and results}}
|
||||
%r = "cmpi"(%c, %c) {predicate: 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Reference in New Issue
Block a user