[mlir][Intrange] Fix materializing ShapedType constant values (#158359)
When materializing integer ranges of splat tensors or vectors as constants, they should use a DenseElementsAttr of the shaped type, not an IntegerAttr of the element type, since the latter can violate the invariants of tensor/vector ops.

Co-authored-by: Jeff Niu <jeffniu@openai.com>
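In other words, the attribute used to materialize a collapsed range must carry the value's own type. A minimal sketch of that dispatch (illustrative only; it condenses the patch below, and the helper name is made up):

```cpp
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APInt.h"

using namespace mlir;

// Illustrative sketch, not the patch itself: pick a constant attribute whose
// type matches the value's type, so a later constant materialization (e.g. an
// arith.constant) stays verifier-clean.
static Attribute makeRangeConstantAttr(Type type, const llvm::APInt &constant) {
  if (isa<IntegerType, IndexType>(type))
    return IntegerAttr::get(type, constant); // scalar: e.g. -1 : i32
  if (auto shapedTy = dyn_cast<ShapedType>(type))
    // splat tensor/vector: e.g. dense<-1> : tensor<128xi32>
    return SplatElementsAttr::get(shapedTy, constant);
  return {}; // unhandled type: no constant in this sketch
}
```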
@@ -26,6 +26,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/DebugStringHelper.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
@@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
   else
     dialect = value.getParentBlock()->getParentOp()->getDialect();
 
-  Type type = getElementTypeOrSelf(value);
-  solver->propagateIfChanged(
-      cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
+  Attribute cstAttr;
+  if (isa<IntegerType, IndexType>(value.getType())) {
+    cstAttr = IntegerAttr::get(value.getType(), *constant);
+  } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
+    cstAttr = SplatElementsAttr::get(shapedTy, *constant);
+  } else {
+    llvm::report_fatal_error(
+        Twine("FIXME: Don't know how to create a constant for this type: ") +
+        mlir::debugString(value.getType()));
+  }
+  solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
 }
 
 LogicalResult IntegerRangeAnalysis::visitOperation(
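For context on why the attribute's type matters (a hedged sketch of assumed downstream use, not code from this patch): the `ConstantValue` recorded above is later turned back into IR by asking the recorded dialect to materialize it, and that hook builds a constant op whose result type is the value's type, so the attribute must already agree with it (dense<-1> : tensor<128xi32> rather than -1 : i32 for a tensor-typed value).

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"

using namespace mlir;

// Hedged sketch of the assumed consumer: fold the lattice's (attribute,
// dialect) pair into an actual constant op. With a scalar IntegerAttr but a
// tensor/vector result type, the materialized op would violate that op's
// invariants; a splat DenseElementsAttr keeps attribute and type consistent.
static Value materializeFromLattice(OpBuilder &b, Dialect *dialect,
                                    Attribute cstAttr, Type type, Location loc) {
  Operation *cst = dialect->materializeConstant(b, cstAttr, type, loc);
  return cst ? cst->getResult(0) : Value();
}
```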
@@ -8,6 +8,7 @@
 #include <utility>
 
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final
     MLIRContext *ctx = op->getContext();
     DataFlowSolver solver;
     solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
     solver.load<IntegerRangeAnalysis>();
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
@@ -132,3 +132,19 @@ func.func @wraps() -> i8 {
   %mod = arith.remsi %val, %c64 : i8
   return %mod : i8
 }
+
+// -----
+
+// CHECK-LABEL: @analysis_crash
+func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> {
+  %c0_i32 = arith.constant 0 : i32
+  %cst = arith.constant dense<-1> : tensor<128xi32>
+  %splat = tensor.splat %arg0 : tensor<128xi32>
+  %0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>) : i32 {
+    scf.yield %arg3 : tensor<128xi32>
+  }
+  %1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32>
+  // Make sure the analysis doesn't crash when materializing the range as a tensor constant.
+  %2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
+  return %2 : tensor<128xi64>
+}