mirror of
https://github.com/intel/llvm.git
synced 2026-02-02 02:00:03 +08:00
[mlir][sparse] connect MapRef's lvl2dim with latest AffineMap computation (#69540)
This makes sure - GEN MAP dim=2 lvl=4 (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2) -- (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3) is indeed encoded as MAP-REF (dim=2, lvl=4) isperm=0 d2l = [ d0/2 d1/2 d0%2 d1%2 ] ld2 = [ l2+2*l0 l3+2*l1 ]
This commit is contained in:
@@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
|
||||
// This code deals with permutations as well as non-permutations that
|
||||
// arise from rank changing blocking.
|
||||
const auto dimToLvl = stt.getDimToLvl();
|
||||
const auto lvlToDim = stt.getLvlToDim();
|
||||
SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
|
||||
SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
|
||||
SmallVector<Value> lvlSizesValues(lvlRank);
|
||||
@@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
|
||||
Dimension d = 0;
|
||||
uint64_t cf = 0, cm = 0;
|
||||
switch (exp.getKind()) {
|
||||
case AffineExprKind::DimId:
|
||||
case AffineExprKind::DimId: {
|
||||
d = exp.cast<AffineDimExpr>().getPosition();
|
||||
break;
|
||||
case AffineExprKind::FloorDiv:
|
||||
d = exp.cast<AffineBinaryOpExpr>()
|
||||
.getLHS()
|
||||
.cast<AffineDimExpr>()
|
||||
.getPosition();
|
||||
cf = exp.cast<AffineBinaryOpExpr>()
|
||||
.getRHS()
|
||||
.cast<AffineConstantExpr>()
|
||||
.getValue();
|
||||
}
|
||||
case AffineExprKind::FloorDiv: {
|
||||
auto floor = exp.cast<AffineBinaryOpExpr>();
|
||||
d = floor.getLHS().cast<AffineDimExpr>().getPosition();
|
||||
cf = floor.getRHS().cast<AffineConstantExpr>().getValue();
|
||||
break;
|
||||
case AffineExprKind::Mod:
|
||||
d = exp.cast<AffineBinaryOpExpr>()
|
||||
.getLHS()
|
||||
.cast<AffineDimExpr>()
|
||||
.getPosition();
|
||||
cm = exp.cast<AffineBinaryOpExpr>()
|
||||
.getRHS()
|
||||
.cast<AffineConstantExpr>()
|
||||
.getValue();
|
||||
}
|
||||
case AffineExprKind::Mod: {
|
||||
auto mod = exp.cast<AffineBinaryOpExpr>();
|
||||
d = mod.getLHS().cast<AffineDimExpr>().getPosition();
|
||||
cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
|
||||
break;
|
||||
}
|
||||
default:
|
||||
llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
|
||||
}
|
||||
dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
|
||||
lvl2dimValues[d] = constantIndex(builder, loc, l); // FIXME, use lvlToDim
|
||||
// Compute the level sizes.
|
||||
// (1) l = d : size(d)
|
||||
// (2) l = d / c : size(d) / c
|
||||
@@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
|
||||
}
|
||||
lvlSizesValues[l] = lvlSz;
|
||||
}
|
||||
// Generate lvl2dim.
|
||||
assert(dimRank == lvlToDim.getNumResults());
|
||||
for (Dimension d = 0; d < dimRank; d++) {
|
||||
AffineExpr exp = lvlToDim.getResult(d);
|
||||
// We expect:
|
||||
// (1) d = l
|
||||
// (2) d = l' * c + l
|
||||
Level l = 0, ll = 0;
|
||||
uint64_t c = 0;
|
||||
switch (exp.getKind()) {
|
||||
case AffineExprKind::DimId: {
|
||||
l = exp.cast<AffineDimExpr>().getPosition();
|
||||
break;
|
||||
}
|
||||
case AffineExprKind::Add: {
|
||||
// Always mul on lhs, symbol/constant on rhs.
|
||||
auto add = exp.cast<AffineBinaryOpExpr>();
|
||||
assert(add.getLHS().getKind() == AffineExprKind::Mul);
|
||||
auto mul = add.getLHS().cast<AffineBinaryOpExpr>();
|
||||
ll = mul.getLHS().cast<AffineDimExpr>().getPosition();
|
||||
c = mul.getRHS().cast<AffineConstantExpr>().getValue();
|
||||
l = add.getRHS().cast<AffineDimExpr>().getPosition();
|
||||
break;
|
||||
}
|
||||
default:
|
||||
llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
|
||||
}
|
||||
lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
|
||||
}
|
||||
// Return buffers.
|
||||
dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
|
||||
lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
|
||||
|
||||
Reference in New Issue
Block a user