Verify CmpIOp's result type to be bool-like

This CL added two new traits, SameOperandsAndResultShape and
ResultsAreBoolLike, and changed CmpIOp to embody these two
traits. As a consequence, CmpIOp's result type now is verified
to be bool-like.

PiperOrigin-RevId: 223208438
This commit is contained in:
Lei Zhang
2018-11-28 11:49:26 -08:00
committed by jpienaar
parent 16f525bc27
commit 1f5330ac90
8 changed files with 108 additions and 18 deletions

View File

@@ -85,6 +85,7 @@ public:
OtherType getTFComplex128Type();
OtherType getTFF32REFType();
IntegerType getI1Type();
IntegerType getIntegerType(unsigned width);
FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
MemRefType getMemRefType(ArrayRef<int> shape, Type elementType,

View File

@@ -283,7 +283,9 @@ bool verifyZeroResult(const Operation *op);
bool verifyOneResult(const Operation *op);
bool verifyNResults(const Operation *op, unsigned numOperands);
bool verifyAtLeastNResults(const Operation *op, unsigned numOperands);
bool verifySameOperandsAndResult(const Operation *op);
bool verifySameOperandsAndResultShape(const Operation *op);
bool verifySameOperandsAndResultType(const Operation *op);
bool verifyResultsAreBoolLike(const Operation *op);
bool verifyResultsAreFloatLike(const Operation *op);
bool verifyResultsAreIntegerLike(const Operation *op);
bool verifyIsTerminator(const Operation *op);
@@ -623,14 +625,40 @@ public:
}
};
/// This class provides verification for ops that are known to have the same
/// operand and result shape: both are scalars, vectors/tensors of the same
/// shape.
template <typename ConcreteType>
class SameOperandsAndResultShape
: public TraitBase<ConcreteType, SameOperandsAndResultShape> {
public:
static bool verifyTrait(const Operation *op) {
return impl::verifySameOperandsAndResultShape(op);
}
};
/// This class provides verification for ops that are known to have the same
/// operand and result type.
///
/// Note: this trait subsumes the SameOperandsAndResultShape trait.
/// Additionally, it requires all operands and results should also have
/// the same element type.
template <typename ConcreteType>
class SameOperandsAndResultType
: public TraitBase<ConcreteType, SameOperandsAndResultType> {
public:
static bool verifyTrait(const Operation *op) {
return impl::verifySameOperandsAndResult(op);
return impl::verifySameOperandsAndResultType(op);
}
};
/// This class verifies that any results of the specified op have a boolean
/// type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
public:
static bool verifyTrait(const Operation *op) {
return impl::verifyResultsAreBoolLike(op);
}
};

View File

@@ -220,9 +220,11 @@ enum class CmpIPredicate {
/// %r1 = cmpi "eq" %0, %1 : i32
/// %r2 = cmpi "slt" %0, %1 : tensor<42x42xi64>
/// %r3 = "cmpi"(%0, %1){predicate: 0} : (i8, i8) -> i1
class CmpIOp : public Op<CmpIOp, OpTrait::OperandsAreIntegerLike,
OpTrait::SameTypeOperands, OpTrait::NOperands<2>::Impl,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
class CmpIOp
: public Op<CmpIOp, OpTrait::OperandsAreIntegerLike,
OpTrait::SameTypeOperands, OpTrait::NOperands<2>::Impl,
OpTrait::OneResult, OpTrait::ResultsAreBoolLike,
OpTrait::SameOperandsAndResultShape, OpTrait::HasNoSideEffect> {
public:
CmpIPredicate getPredicate() const {
return (CmpIPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())

View File

@@ -84,6 +84,8 @@ OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
OtherType Builder::getTFStringType() { return Type::getTFString(context); }
IntegerType Builder::getI1Type() { return Type::getInteger(1, context); }
IntegerType Builder::getIntegerType(unsigned width) {
return Type::getInteger(width, context);
}

View File

@@ -233,7 +233,7 @@ bool CondBranchOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType condInfo;
// Parse the condition.
Type int1Ty = parser->getBuilder().getIntegerType(1);
Type int1Ty = parser->getBuilder().getI1Type();
if (parser->parseOperand(condInfo) || parser->parseComma() ||
parser->resolveOperand(condInfo, int1Ty, result->operands)) {
return parser->emitError(parser->getNameLoc(),

View File

@@ -502,7 +502,51 @@ bool OpTrait::impl::verifyAtLeastNResults(const Operation *op,
return false;
}
bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
/// Returns false if the given two types have the same shape. That is,
/// they are both scalars, or they are both vectors / ranked tensors with
/// the same dimension specifications. The element type does not matter.
static bool verifyShapeMatch(Type type1, Type type2) {
// Check scalar cases
if (type1.isa<IntegerType>() || type1.isa<FloatType>() ||
type1.isa<IndexType>())
return !(type2.isa<IntegerType>() || type2.isa<FloatType>() ||
type2.isa<IndexType>());
// Check unranked tensor cases
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>())
return true;
// Check normal vector/tensor cases
if (auto vtType1 = type1.dyn_cast<VectorOrTensorType>()) {
auto vtType2 = type2.dyn_cast<VectorOrTensorType>();
return !(vtType2 && vtType1.getShape() == vtType2.getShape());
}
return false;
}
bool OpTrait::impl::verifySameOperandsAndResultShape(const Operation *op) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return true;
auto type = op->getOperand(0)->getType();
for (unsigned i = 0, e = op->getNumResults(); i < e; ++i) {
if (verifyShapeMatch(op->getResult(i)->getType(), type))
return op->emitOpError(
"requires the same shape for all operands and results");
}
for (unsigned i = 1, e = op->getNumOperands(); i < e; ++i) {
if (verifyShapeMatch(op->getOperand(i)->getType(), type))
return op->emitOpError(
"requires the same shape for all operands and results");
}
return false;
}
bool OpTrait::impl::verifySameOperandsAndResultType(const Operation *op) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return true;
auto type = op->getResult(0)->getType();
for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
if (op->getResult(i)->getType() != type)
@@ -574,6 +618,18 @@ bool OpTrait::impl::verifyIsTerminator(const Operation *op) {
return false;
}
bool OpTrait::impl::verifyResultsAreBoolLike(const Operation *op) {
for (auto *result : op->getResults()) {
auto elementType = getTensorOrVectorElementType(result->getType());
auto intType = elementType.dyn_cast<IntegerType>();
bool isBoolType = intType && intType.getWidth() == 1;
if (!isBoolType)
return op->emitOpError("requires a bool result type");
}
return false;
}
bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
for (auto *result : op->getResults()) {
if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())

View File

@@ -426,7 +426,7 @@ bool CallIndirectOp::verify() const {
// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getI1SameShape(Builder *build, Type type) {
auto i1Type = build->getIntegerType(1);
auto i1Type = build->getI1Type();
if (type.isa<IntegerType>() || type.isa<FloatType>() || type.isa<IndexType>())
return i1Type;
if (auto tensorType = type.dyn_cast<RankedTensorType>())
@@ -532,8 +532,7 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
attrs[0].second = builder.getIntegerAttr(static_cast<int64_t>(predicate));
result->attributes = attrs;
// The result of comparison is formed from i1s in the same shape as type.
result->addTypes({getI1SameShape(&parser->getBuilder(), type)});
result->addTypes({getI1SameShape(&builder, type)});
return false;
}
@@ -568,12 +567,6 @@ bool CmpIOp::verify() const {
predicate >= (int64_t)CmpIPredicate::NumPredicates)
return emitOpError("'predicate' attribute value out of range");
if (getOperand(0)->getType() != getOperand(1)->getType())
return emitOpError("requires operands to have the same type");
if (checkI1SameShape(getOperand(0)->getType(), getResult()->getType()))
return emitOpError("result must have the same shape as inputs");
return false;
}

View File

@@ -215,6 +215,14 @@ bb0(%a : f32, %b : f32):
// -----
// Result type must be boolean like.
cfgfunc @cfgfunc_with_ops(i32, i32) {
bb0(%a : i32, %b : i32):
%r = "cmpi"(%a, %b) {predicate: 0} : (i32, i32) -> i32 // expected-error {{op requires a bool result type}}
}
// -----
cfgfunc @cfgfunc_with_ops(i32, i32) {
bb0(%a : i32, %b : i32):
// expected-error@+1 {{requires an integer attribute named 'predicate'}}
@@ -226,8 +234,8 @@ bb0(%a : i32, %b : i32):
cfgfunc @cfgfunc_with_ops() {
bb0:
%c = constant splat<vector<42 x i32>, 0> : vector<42 x i32>
// expected-error@+1 {{op result must have the same shape as inputs}}
%r = "cmpi"(%c, %c) {predicate: 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<42 x i32>
// expected-error@+1 {{op requires the same shape for all operands and results}}
%r = "cmpi"(%c, %c) {predicate: 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1>
}
// -----