From 62896428a766ecdec28432d3bfea3155cbeafe78 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Wed, 14 Dec 2022 12:58:50 -0800 Subject: [PATCH] [mlir][sparse] Moving/renaming genBuffer to allocaBuffer This allows allocaBuffer to be used outside of SparseTensorConversion.cpp, which will be helpful for a some future commits. Reviewed By: aartbik, Peiming Differential Revision: https://reviews.llvm.org/D140047 --- .../SparseTensor/Transforms/CodegenUtils.cpp | 12 +++++++ .../SparseTensor/Transforms/CodegenUtils.h | 5 +++ .../Transforms/SparseTensorConversion.cpp | 32 ++++++------------- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp index ede9c56b7b70..b5c82bd6db79 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -1151,6 +1151,18 @@ Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc, return builder.create(loc, MemRefType::get({}, tp)); } +Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc, + ValueRange values) { + const unsigned sz = values.size(); + assert(sz >= 1); + Value buffer = genAlloca(builder, loc, sz, values[0].getType()); + for (unsigned i = 0; i < sz; i++) { + Value idx = constantIndex(builder, loc, i); + builder.create(loc, values[i], buffer, idx); + } + return buffer; +} + Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc, RankedTensorType tensorTp, ValueRange sizes) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h index 4d5805e6de93..a121522d0190 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h @@ -136,6 +136,11 @@ Value genAlloca(OpBuilder &builder, Location loc, unsigned sz, Type tp); /// of the given type, and returns the `memref<$tp>`. Value genAllocaScalar(OpBuilder &builder, Location loc, Type tp); +/// Generates a temporary buffer, initializes it with the given contents, +/// and returns it as type `memref` (rather than specifying the +/// size of the buffer). +Value allocaBuffer(OpBuilder &builder, Location loc, ValueRange values); + /// Generates code to allocate a buffer of the given type, and zero /// initialize it. If the buffer type has any dynamic sizes, then the /// `sizes` parameter should be as filled by sizesFromPtr(); that way diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 2603b8d6b1ef..4f4dbe49926f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -207,18 +207,6 @@ static Value genAlloc(RewriterBase &rewriter, Location loc, Value sz, Type tp) { return rewriter.create(loc, memTp, ValueRange{sz}); } -/// Generates a temporary buffer of the given type and given contents. -static Value genBuffer(OpBuilder &builder, Location loc, ValueRange values) { - unsigned sz = values.size(); - assert(sz >= 1); - Value buffer = genAlloca(builder, loc, sz, values[0].getType()); - for (unsigned i = 0; i < sz; i++) { - Value idx = constantIndex(builder, loc, i); - builder.create(loc, values[i], buffer, idx); - } - return buffer; -} - /// Generates a temporary buffer for the level-types of the given encoding. static Value genLvlTypesBuffer(OpBuilder &builder, Location loc, SparseTensorEncodingAttr enc) { @@ -227,7 +215,7 @@ static Value genLvlTypesBuffer(OpBuilder &builder, Location loc, lvlTypes.reserve(dlts.size()); for (auto dlt : dlts) lvlTypes.push_back(constantDimLevelTypeEncoding(builder, loc, dlt)); - return genBuffer(builder, loc, lvlTypes); + return allocaBuffer(builder, loc, lvlTypes); } /// This class abstracts over the API of `_mlir_ciface_newSparseTensor`: @@ -329,7 +317,7 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc, // Dimension-sizes array of the enveloping tensor. Useful for either // verification of external data, or for construction of internal data. assert(dimSizes.size() == dimRank && "Dimension-rank mismatch"); - params[kParamDimSizes] = genBuffer(builder, loc, dimSizes); + params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizes); // The level-sizes array must be passed as well, since for arbitrary // dim2lvl mappings it cannot be trivially reconstructed at runtime. // For now however, since we're still assuming permutations, we will @@ -358,10 +346,10 @@ NewCallParams &NewCallParams::genBuffers(SparseTensorEncodingAttr enc, lvlSizes[i] = dimSizes[i]; } } - params[kParamLvlSizes] = genBuffer(builder, loc, lvlSizes); - params[kParamLvl2Dim] = genBuffer(builder, loc, lvl2dim); + params[kParamLvlSizes] = allocaBuffer(builder, loc, lvlSizes); + params[kParamLvl2Dim] = allocaBuffer(builder, loc, lvl2dim); params[kParamDim2Lvl] = - dimOrder ? genBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim]; + dimOrder ? allocaBuffer(builder, loc, dim2lvl) : params[kParamLvl2Dim]; // Secondary and primary types encoding. setTemplateTypes(enc, stp); // Finally, make note that initialization is complete. @@ -780,7 +768,7 @@ public: // Construct the dimShape. const auto dimShape = stp.getShape(); SmallVector dimShapeValues = getDimShape(rewriter, loc, stp); - Value dimShapeBuffer = genBuffer(rewriter, loc, dimShapeValues); + Value dimShapeBuffer = allocaBuffer(rewriter, loc, dimShapeValues); // Allocate `SparseTensorReader` and perform all initial setup that // does not depend on lvlSizes (nor dim2lvl, lvl2dim, etc). Type opaqueTp = getOpaquePointerType(rewriter); @@ -833,9 +821,9 @@ public: ? rewriter.create(loc, dimSizesBuffer, dim) : dimShapeValues[d]; } - lvlSizesBuffer = genBuffer(rewriter, loc, lvlSizeValues); - lvl2dimBuffer = genBuffer(rewriter, loc, lvl2dimValues); - dim2lvlBuffer = genBuffer(rewriter, loc, dim2lvlValues); + lvlSizesBuffer = allocaBuffer(rewriter, loc, lvlSizeValues); + lvl2dimBuffer = allocaBuffer(rewriter, loc, lvl2dimValues); + dim2lvlBuffer = allocaBuffer(rewriter, loc, dim2lvlValues); } else { assert(dimRank == lvlRank && "Rank mismatch"); SmallVector iotaValues; @@ -843,7 +831,7 @@ public: for (unsigned i = 0; i < lvlRank; i++) iotaValues.push_back(constantIndex(rewriter, loc, i)); lvlSizesBuffer = dimSizesBuffer ? dimSizesBuffer : dimShapeBuffer; - dim2lvlBuffer = lvl2dimBuffer = genBuffer(rewriter, loc, iotaValues); + dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(rewriter, loc, iotaValues); } // Use the `reader` to parse the file. SmallVector params{