[mlir][sparse] Remove the expansion of symmetric MTX in the sparse tensor storage.

We will support symmetric MTX without expanding the data in the sparse tensor
storage.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D144059
This commit is contained in:
bixia1
2023-02-16 09:52:01 -08:00
parent 5172877bbd
commit c2e248c6ae
11 changed files with 31 additions and 164 deletions

View File

@@ -27,7 +27,7 @@ class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
//===----------------------------------------------------------------------===//
def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
Arguments<(ins AnyType:$source, UnitAttr:$expandSymmetry)>,
Arguments<(ins AnyType:$source)>,
Results<(outs AnySparseTensor:$result)> {
string summary = "Materializes a new sparse tensor from given source";
string description = [{
@@ -40,12 +40,9 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
code. The operation is provided as an anchor that materializes a properly
typed sparse tensor with initial contents into a computation.
An optional attribute `expandSymmetry` can be used to extend this operation
to make symmetry in external formats explicit in the storage. That is, when
the attribute is present and a non-zero value is discovered at (i, j) where
i!=j, we add the same value to (j, i). This claims more storage than a pure
symmetric storage, and thus may cause a bad performance hit. True symmetric
storage is planned for the future.
Reading in a symmetric matrix will result in just the lower/upper triangular
part of the matrix (so that only relevant information is stored). Proper
symmetry support for operating on symmetric matrices is still TBD.
Example:
@@ -53,9 +50,7 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
sparse_tensor.new %source : !Source to tensor<1024x1024xf64, #CSR>
```
}];
let assemblyFormat = "(`expand_symmetry` $expandSymmetry^)? $source attr-dict"
"`:` type($source) `to` type($result)";
let hasVerifier = 1;
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
}
def SparseTensor_PackOp : SparseTensor_Op<"pack", [Pure]>,

View File

@@ -276,14 +276,14 @@ private:
}
/// The internal implementation of `readCOO`. We template over
/// `IsPattern` and `IsSymmetric` in order to perform LICM without
/// needing to duplicate the source code.
/// `IsPattern` in order to perform LICM without needing to duplicate the
/// source code.
//
// TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
// since that's what `readCOO` creates. Once we update `readCOO` to
// functionalize the mapping, then this helper will just take that
// same function.
template <typename V, bool IsPattern, bool IsSymmetric>
template <typename V, bool IsPattern>
void readCOOLoop(uint64_t lvlRank, detail::PermutationRef dim2lvl,
SparseTensorCOO<V> *lvlCOO);
@@ -323,21 +323,16 @@ SparseTensorCOO<V> *SparseTensorReader::readCOO(uint64_t lvlRank,
auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNNZ());
// Do some manual LICM, to avoid assertions in the for-loop.
const bool IsPattern = isPattern();
const bool IsSymmetric = (isSymmetric() && getRank() == 2);
if (IsPattern && IsSymmetric)
readCOOLoop<V, true, true>(lvlRank, d2l, lvlCOO);
else if (IsPattern)
readCOOLoop<V, true, false>(lvlRank, d2l, lvlCOO);
else if (IsSymmetric)
readCOOLoop<V, false, true>(lvlRank, d2l, lvlCOO);
if (IsPattern)
readCOOLoop<V, true>(lvlRank, d2l, lvlCOO);
else
readCOOLoop<V, false, false>(lvlRank, d2l, lvlCOO);
readCOOLoop<V, false>(lvlRank, d2l, lvlCOO);
// Close the file and return the COO.
closeFile();
return lvlCOO;
}
template <typename V, bool IsPattern, bool IsSymmetric>
template <typename V, bool IsPattern>
void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
detail::PermutationRef dim2lvl,
SparseTensorCOO<V> *lvlCOO) {
@@ -353,16 +348,6 @@ void SparseTensorReader::readCOOLoop(uint64_t lvlRank,
dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
// TODO: <https://github.com/llvm/llvm-project/issues/54179>
lvlCOO->add(lvlInd, value);
// We currently chose to deal with symmetric matrices by fully
// constructing them. In the future, we may want to make symmetry
// implicit for storage reasons.
if constexpr (IsSymmetric)
if (dimInd[0] != dimInd[1]) {
// Must recompute `lvlInd`, since arbitrary maps don't preserve swap.
std::swap(dimInd[0], dimInd[1]);
dim2lvl.pushforward(dimRank, dimInd.data(), lvlInd.data());
lvlCOO->add(lvlInd, value);
}
}
}

View File

@@ -651,12 +651,6 @@ static LogicalResult verifySparsifierGetterSetter(
return success();
}
LogicalResult NewOp::verify() {
if (getExpandSymmetry() && getDimRank(getResult()) != 2)
return emitOpError("expand_symmetry can only be used for 2D tensors");
return success();
}
static LogicalResult verifyPackUnPack(Operation *op, TensorType cooTp,
TensorType dataTp, TensorType idxTp) {
if (!isUniqueCOOType(cooTp))

View File

@@ -1047,16 +1047,6 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
/*sizeHint=*/nnz, Attribute())
.getResult();
// The verifier ensures only 2D tensors can have the expandSymmetry flag.
Value symmetric;
if (dimRank == 2 && op.getExpandSymmetry()) {
symmetric =
createFuncCall(rewriter, loc, "getSparseTensorReaderIsSymmetric",
{rewriter.getI1Type()}, {reader}, EmitCInterface::Off)
.getResult(0);
} else {
symmetric = Value();
}
Type eltTp = dstTp.getElementType();
Value value = genAllocaScalar(rewriter, loc, eltTp);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1,
@@ -1077,21 +1067,6 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
Value v = rewriter.create<memref::LoadOp>(loc, value);
Value t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
indicesArray);
if (symmetric) {
Value eq = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ne, indicesArray[0], indicesArray[1]);
Value cond = rewriter.create<arith::AndIOp>(loc, symmetric, eq);
scf::IfOp ifOp =
rewriter.create<scf::IfOp>(loc, t.getType(), cond, /*else*/ true);
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
rewriter.create<scf::YieldOp>(
loc, Value(rewriter.create<InsertOp>(
loc, v, t, ValueRange{indicesArray[1], indicesArray[0]})));
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
rewriter.create<scf::YieldOp>(loc, t);
t = ifOp.getResult(0);
rewriter.setInsertionPointAfter(ifOp);
}
rewriter.create<scf::YieldOp>(loc, ArrayRef<Value>(t));
rewriter.setInsertionPointAfter(forOp);
// Link SSA chain.

View File

@@ -400,16 +400,6 @@ func.func @sparse_wrong_arity_compression(%arg0: memref<?xf64>,
// -----
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
func.func @sparse_new(%arg0: !llvm.ptr<i8>) {
// expected-error@+1 {{expand_symmetry can only be used for 2D tensors}}
%0 = sparse_tensor.new expand_symmetry %arg0 : !llvm.ptr<i8> to tensor<128xf64, #SparseVector>
return
}
// -----
func.func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
// expected-error@+1 {{unexpected type in convert}}
%0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>

View File

@@ -10,45 +10,6 @@
dimOrdering = affine_map<(i, j) -> (j, i)>
}>
// CHECK-LABEL: func.func @sparse_new_symmetry(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[R:.*]] = call @createSparseTensorReader(%[[A]])
// CHECK: %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
// CHECK: call @copySparseTensorReaderDimSizes(%[[R]], %[[DS]])
// CHECK: %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
// CHECK: %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
// CHECK: %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
// CHECK: %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]]) size_hint=%[[N]]
// CHECK: %[[S:.*]] = call @getSparseTensorReaderIsSymmetric(%[[R]])
// CHECK: %[[VB:.*]] = memref.alloca()
// CHECK: %[[T2:.*]] = scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] iter_args(%[[A2:.*]] = %[[T]])
// CHECK: func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]], %[[VB]])
// CHECK: %[[E0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
// CHECK: %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
// CHECK: %[[V:.*]] = memref.load %[[VB]][]
// CHECK: %[[T1:.*]] = sparse_tensor.insert %[[V]] into %[[A2]]{{\[}}%[[E0]], %[[E1]]]
// CHECK: %[[NE:.*]] = arith.cmpi ne, %[[E0]], %[[E1]]
// CHECK: %[[COND:.*]] = arith.andi %[[S]], %[[NE]]
// CHECK: %[[T3:.*]] = scf.if %[[COND]]
// CHECK: %[[T4:.*]] = sparse_tensor.insert %[[V]] into %[[T1]]{{\[}}%[[E1]], %[[E0]]]
// CHECK: scf.yield %[[T4]]
// CHECK: else
// CHECK: scf.yield %[[T1]]
// CHECK: scf.yield %[[T3]]
// CHECK: }
// CHECK: call @delSparseTensorReader(%[[R]])
// CHECK: %[[T5:.*]] = sparse_tensor.load %[[T2]] hasInserts
// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T5]]
// CHECK: bufferization.dealloc_tensor %[[T5]]
// CHECK: return %[[R]]
func.func @sparse_new_symmetry(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
%0 = sparse_tensor.new expand_symmetry %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
return %0 : tensor<?x?xf32, #CSR>
}
// CHECK-LABEL: func.func @sparse_new(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index

View File

@@ -44,19 +44,6 @@ func.func @sparse_unpack(%sp : tensor<100xf64, #SparseVector>)
// -----
#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
// CHECK-LABEL: func @sparse_new_symmetry(
// CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>)
// CHECK: %[[T:.*]] = sparse_tensor.new expand_symmetry %[[A]] : !llvm.ptr<i8> to tensor<?x?xf64, #{{.*}}>
// CHECK: return %[[T]] : tensor<?x?xf64, #{{.*}}>
func.func @sparse_new_symmetry(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf64, #SparseMatrix> {
%0 = sparse_tensor.new expand_symmetry %arg0 : !llvm.ptr<i8> to tensor<?x?xf64, #SparseMatrix>
return %0 : tensor<?x?xf64, #SparseMatrix>
}
// -----
#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
// CHECK-LABEL: func @sparse_dealloc(

View File

@@ -28,6 +28,9 @@
// REDEFINE: FileCheck %s
// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run}
// TODO: The test currently only operates on the triangular part of the
// symmetric matrix.
!Filename = !llvm.ptr<i8>
#SparseMatrix = #sparse_tensor.encoding<{
@@ -79,7 +82,7 @@ module {
// Read the sparse matrix from file, construct sparse storage.
%fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
%a = sparse_tensor.new expand_symmetry %fileName : !Filename to tensor<?x?xf64, #SparseMatrix>
%a = sparse_tensor.new %fileName : !Filename to tensor<?x?xf64, #SparseMatrix>
// Call the kernel.
%0 = call @kernel_sum_reduce(%a, %x)
@@ -87,7 +90,7 @@ module {
// Print the result for verification.
//
// CHECK: 30.2
// CHECK: 24.1
//
%v = tensor.extract %0[] : tensor<f64>
vector.print %v : f64

View File

@@ -28,6 +28,9 @@
// REDEFINE: FileCheck %s
// RUN: %{compile} | mlir-translate -mlir-to-llvmir | %{run}
// TODO: The test currently only operates on the triangular part of the
// symmetric matrix.
!Filename = !llvm.ptr<i8>
#SparseMatrix = #sparse_tensor.encoding<{
@@ -82,7 +85,7 @@ module {
// Read the sparse matrix from file, construct sparse storage.
%fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
%a = sparse_tensor.new expand_symmetry %fileName : !Filename to tensor<?x?xcomplex<f64>, #SparseMatrix>
%a = sparse_tensor.new %fileName : !Filename to tensor<?x?xcomplex<f64>, #SparseMatrix>
// Call the kernel.
%0 = call @kernel_sum_reduce(%a, %x)
@@ -90,8 +93,8 @@ module {
// Print the result for verification.
//
// CHECK: 30.2
// CHECK-NEXT: 22.2
// CHECK: 24.1
// CHECK-NEXT: 16.1
//
%v = tensor.extract %0[] : tensor<complex<f64>>
%real = complex.re %v : complex<f64>

View File

@@ -1,11 +1,14 @@
%%MatrixMarket matrix coordinate real symmetric
%%MatrixMarket matrix coordinate real general
%-------------------------------------------------------------------------------
% To download a matrix for a real world application
% https://math.nist.gov/MatrixMarket/
%-------------------------------------------------------------------------------
3 3 5
3 3 8
1 1 37423.0879671
1 2 -22.4050781162
1 3 -300.654980157
2 1 -22.4050781162
2 3 -.00869762944058
3 1 -300.654980157
3 2 -.00869762944058
3 3 805225.750212

View File

@@ -21,18 +21,12 @@ _DENSE = mlir_pytaco.ModeFormat.DENSE
_FORMAT = mlir_pytaco.Format([_COMPRESSED, _COMPRESSED])
_MTX_DATA_TEMPLATE = Template(
"""%%MatrixMarket matrix coordinate real $general_or_symmetry
_MTX_DATA = """%%MatrixMarket matrix coordinate real general
3 3 3
3 1 3
1 2 2
3 2 4
""")
def _get_mtx_data(value):
mtx_data = _MTX_DATA_TEMPLATE
return mtx_data.substitute(general_or_symmetry=value)
"""
# CHECK-LABEL: test_read_mtx_matrix_general
@@ -41,7 +35,7 @@ def test_read_mtx_matrix_general():
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.mtx")
with open(file_name, "w") as file:
file.write(_get_mtx_data("general"))
file.write(_MTX_DATA)
a = mlir_pytaco_io.read(file_name, _FORMAT)
passed = 0
# The value of a is stored as an MLIR sparse tensor.
@@ -55,29 +49,6 @@ def test_read_mtx_matrix_general():
print(passed)
# CHECK-LABEL: test_read_mtx_matrix_symmetry
@testing_utils.run_test
def test_read_mtx_matrix_symmetry():
with tempfile.TemporaryDirectory() as test_dir:
file_name = os.path.join(test_dir, "data.mtx")
with open(file_name, "w") as file:
file.write(_get_mtx_data("symmetric"))
a = mlir_pytaco_io.read(file_name, _FORMAT)
passed = 0
# The value of a is stored as an MLIR sparse tensor.
passed += (not a.is_unpacked())
a.unpack()
passed += (a.is_unpacked())
coords, values = a.get_coordinates_and_values()
print(coords)
print(values)
passed += np.array_equal(coords,
[[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]])
passed += np.allclose(values, [2.0, 3.0, 2.0, 4.0, 3.0, 4.0])
# CHECK: 4
print(passed)
_TNS_DATA = """2 3
3 2
3 1 3