[mlir][Intrange] Fix materializing ShapedType constant values (#158359)

When materializing integer ranges of splat tensors or vector as
constants, they should use DenseElementsAttr of the shaped type, not
IntegerAttrs of the element types, since this can violate the invariants
of tensor/vector ops.

Co-authored-by: Jeff Niu <jeffniu@openai.com>
This commit is contained in:
Jeff Niu
2025-09-12 13:53:32 -07:00
committed by GitHub
parent b5516dad6e
commit 86bcd1c2b2
3 changed files with 30 additions and 3 deletions

View File

@@ -26,6 +26,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
@@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
else
dialect = value.getParentBlock()->getParentOp()->getDialect();
Type type = getElementTypeOrSelf(value);
solver->propagateIfChanged(
cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
Attribute cstAttr;
if (isa<IntegerType, IndexType>(value.getType())) {
cstAttr = IntegerAttr::get(value.getType(), *constant);
} else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
cstAttr = SplatElementsAttr::get(shapedTy, *constant);
} else {
llvm::report_fatal_error(
Twine("FIXME: Don't know how to create a constant for this type: ") +
mlir::debugString(value.getType()));
}
solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
}
LogicalResult IntegerRangeAnalysis::visitOperation(

View File

@@ -8,6 +8,7 @@
#include <utility>
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final
MLIRContext *ctx = op->getContext();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
solver.load<IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();

View File

@@ -132,3 +132,19 @@ func.func @wraps() -> i8 {
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}
// -----
// CHECK-LABEL: @analysis_crash
func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<-1> : tensor<128xi32>
%splat = tensor.splat %arg0 : tensor<128xi32>
%0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>) : i32 {
scf.yield %arg3 : tensor<128xi32>
}
%1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32>
// Make sure the analysis doesn't crash when materializing the range as a tensor constant.
%2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
return %2 : tensor<128xi64>
}