mirror of
https://github.com/intel/llvm.git
synced 2026-01-25 10:55:58 +08:00
[MLIR][Shape] Fix lowering of shape.get_extent
The declarative conversion patterns caused crashes in the asan configuration. The non-declarative implementation circumvents this. Differential Revision: https://reviews.llvm.org/D82797
This commit is contained in:
@@ -90,6 +90,29 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class GetExtentOpConverter : public OpConversionPattern<GetExtentOp> {
|
||||
using OpConversionPattern<GetExtentOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(GetExtentOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
GetExtentOp::Adaptor transformed(operands);
|
||||
|
||||
// Derive shape extent directly from shape origin if possible.
|
||||
// This circumvents the necessity to materialize the shape in memory.
|
||||
if (auto shapeOfOp = op.shape().getDefiningOp<ShapeOfOp>()) {
|
||||
rewriter.replaceOpWithNewOp<DimOp>(op, shapeOfOp.arg(),
|
||||
transformed.dim());
|
||||
return success();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<ExtractElementOp>(
|
||||
op, rewriter.getIndexType(), transformed.shape(),
|
||||
ValueRange{transformed.dim()});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class RankOpConverter : public OpConversionPattern<shape::RankOp> {
|
||||
public:
|
||||
using OpConversionPattern<shape::RankOp>::OpConversionPattern;
|
||||
@@ -161,6 +184,7 @@ void mlir::populateShapeToStandardConversionPatterns(
|
||||
BinaryOpConversion<AddOp, AddIOp>,
|
||||
BinaryOpConversion<MulOp, MulIOp>,
|
||||
ConstSizeOpConverter,
|
||||
GetExtentOpConverter,
|
||||
RankOpConverter,
|
||||
ShapeOfOpConversion>(ctx);
|
||||
// clang-format on
|
||||
|
||||
@@ -19,20 +19,3 @@ def SizeToIndexOpConversion : Pat<
|
||||
(Shape_SizeToIndexOp $arg),
|
||||
(replaceWithValue $arg)>;
|
||||
|
||||
// Derive shape extent directly from shape origin if possible.
|
||||
// This circumvents the necessity to materialize the shape in memory.
|
||||
def GetExtentShapeOfConversion : Pat<
|
||||
(Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx),
|
||||
(Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx))),
|
||||
[],
|
||||
(addBenefit 10)>;
|
||||
def GetExtentFromExtentTensorConversion : Pattern<
|
||||
(Shape_GetExtentOp (Shape_FromExtentTensorOp $extents), $idx),
|
||||
[
|
||||
(Shape_SizeToIndexOp:$std_idx $idx),
|
||||
(ExtractElementOp:$std_result $extents, (NativeCodeCall<"ValueRange({$0})"> $std_idx)),
|
||||
(Shape_IndexToSizeOp $std_result)
|
||||
],
|
||||
[],
|
||||
(addBenefit 10)>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user