[mlir] Add shape.with_shape op

This is an operation that can returns a new ValueShape with a different shape. Useful for composing shape function calls and reusing existing shape transfer functions.

Just adding the op in this change.

Differential Revision: https://reviews.llvm.org/D84217
This commit is contained in:
Jacques Pienaar
2020-07-31 14:46:48 -07:00
parent 2a6c8b2e95
commit 86a78546b9
3 changed files with 62 additions and 2 deletions

View File

@@ -100,7 +100,11 @@ def Shape_ValueShapeType : DialectType<ShapeDialect,
corresponds to `Value` in the compiler) and a shape. Conceptually this is a
tuple of a value (potentially unknown) and `shape.type`. The value and shape
can either or both be unknown. If both the `value` and `shape` are known,
then the shape of `value` is conformant with `shape`.
then the shape of `value` is conformant with `shape`. That is, the shape of
the value conforms to the shape of the ValueShape, so that if we have
ValueShape `(value, shape)` then `join(shape_of(value), shape)` would be
error free and in particular it means that if both are statically known,
then they are equal.
}];
}

View File

@@ -432,6 +432,49 @@ def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
let hasCanonicalizer = 1;
}
def Shape_WithOp : Shape_Op<"with_shape", [NoSideEffect]> {
let summary = "Returns ValueShape with given shape";
let description = [{
Returns ValueShape with the shape updated to match the shape operand. That
is a new ValueShape tuple is created with value equal to `operand`'s
value and shape equal to `shape`. If the ValueShape and given `shape` are
non-conformant, then the returned ValueShape will represent an error of
this mismatch. Similarly if either inputs are in an error state, then an
error is popagated.
Usage:
%0 = shape.with_shape %1, %2 : tensor<...>, !shape.shape
This is used, for example, where one combines shape function calculations
and/or call one shape function from another. E.g.,
```mlir
func @shape_foobah(%a: !shape.value_shape,
%b: !shape.value_shape,
%c: !shape.value_shape) -> !shape.shape {
%0 = call @shape_foo(%a, %b) :
(!shape.value_shape, !shape.value_shape) -> !shape.shape
%1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
%2 = call @shape_bah(%c, %1) :
(!shape.value_shape, !shape.value_shape) -> !shape.shape
return %2 : !shape.shape
}
```
This op need not be a refinement of the shape. In non-error cases the input
ValueShape's value and shape are conformant and so too for the output, but
the result may be less specified than `operand`'s shape as `shape` is
merely used to construct the new ValueShape. If join behavior is desired
then a join op should be used.
}];
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
Shape_ShapeType:$shape);
let results = (outs Shape_ValueShapeType:$result);
let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
}
def Shape_YieldOp : Shape_Op<"yield",
[HasParent<"ReduceOp">,
NoSideEffect,

View File

@@ -221,4 +221,17 @@ func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
return %result : !shape.size
}
// Testing nvoking shape function from another. shape_equal_shapes is merely
// a trivial helper function to invoke elsewhere.
func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
%1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
return %2 : !shape.shape
}
func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
%1 = shape.with_shape %b, %0 : !shape.value_shape, !shape.shape
%2 = call @shape_equal_shapes(%a, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
return %2 : !shape.shape
}