Introduce pretty syntax for shape_cast as discussed on the list last week.

PiperOrigin-RevId: 212823681
This commit is contained in:
Chris Lattner
2018-09-13 09:16:32 -07:00
committed by jpienaar
parent a7611790f8
commit a21f2f453d
5 changed files with 48 additions and 4 deletions

View File

@@ -186,6 +186,9 @@ public:
/// Parse a colon followed by a type list, which must have at least one type.
virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result) = 0;
/// Parse a keyword followed by a type.
virtual bool parseKeywordType(const char *keyword, Type *&result) = 0;
/// Add the specified type to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and
/// chain through || operators.

View File

@@ -427,6 +427,10 @@ class ReturnOp
public:
static StringRef getOperationName() { return "return"; }
// TODO When there is a client.
// static void build(Builder *builder, OperationState *result,
// ArrayRef<SSAValue *> results);
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
@@ -461,10 +465,8 @@ public:
}
// Hooks to customize behavior of this op.
// TODO(clattner): Add parse/print hooks when we agree about the concrete
// syntax.
// static bool parse(OpAsmParser *parser, OperationState *result);
// void print(OpAsmPrinter *p) const;
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
private:

View File

@@ -754,6 +754,20 @@ bool ShapeCastOp::verify() const {
return false;
}
void ShapeCastOp::print(OpAsmPrinter *p) const {
*p << "shape_cast " << *getOperand() << " : " << *getOperand()->getType()
<< " to " << *getType();
}
bool ShapeCastOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType srcInfo;
Type *srcType, *dstType;
return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
parser->resolveOperand(srcInfo, srcType, result->operands) ||
parser->parseKeywordType("to", dstType) ||
parser->addTypeToList(dstType, result->types);
}
//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//

View File

@@ -1733,6 +1733,14 @@ public:
return false;
}
/// Parse a keyword followed by a type.
bool parseKeywordType(const char *keyword, Type *&result) override {
if (parser.getTokenSpelling() != keyword)
return parser.emitError("expected '" + Twine(keyword) + "'");
parser.consumeToken();
return !(result = parser.parseType());
}
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. this
/// captures the location of the attribute in 'loc' if it is non-null.

View File

@@ -145,3 +145,20 @@ mlfunc @extract_element(%arg0 : tensor<??i32>, %arg1 : tensor<4x4xf32>) -> i32 {
return %0 : i32
}
// CHECK-LABEL: mlfunc @shape_cast(%arg0
mlfunc @shape_cast(%arg0 : tensor<??f32>, %arg1 : tensor<4x4xf32>, %arg2 : tensor<?x?xf32>) {
// CHECK: %0 = shape_cast %arg0 : tensor<??f32> to tensor<?x?xf32>
%0 = shape_cast %arg0 : tensor<??f32> to tensor<?x?xf32>
// CHECK: %1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<??f32>
%1 = shape_cast %arg1 : tensor<4x4xf32> to tensor<??f32>
// CHECK: %2 = shape_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
%2 = shape_cast %arg2 : tensor<?x?xf32> to tensor<4x?xf32>
// CHECK: %3 = shape_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
%3 = shape_cast %2 : tensor<4x?xf32> to tensor<?x?xf32>
return
}