[mlir][sparse] replace ad-hoc MemRef struct with CRunnerUtils definition

This revision removes the ad-hoc MemRef structs that were needed under the old
ABI (when memrefs were still passed by value) and replaces them with the shared
StridedMemRefType definitions of CRunnerUtils (possible now that memrefs are
passed by pointer). This avoids code duplication and ensures a consistent
view of strided memory references across all our support libraries.
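
For reference, the shared descriptor from CRunnerUtils.h that the diff below
switches to is essentially the following (a simplified sketch; the actual
header also defines iterators and other helpers):

// Simplified sketch of the shared descriptor declared in
// mlir/include/mlir/ExecutionEngine/CRunnerUtils.h (helpers omitted).
template <typename T, int N>
struct StridedMemRefType {
  T *basePtr;         // pointer returned by the allocator
  T *data;            // aligned pointer used for element access
  int64_t offset;     // offset (in elements) to the first element
  int64_t sizes[N];   // size of each dimension
  int64_t strides[N]; // stride (in elements) of each dimension
};

Element i of a 1-D memref is data[offset + i * strides[0]], which is why the
support methods below publish a dense vector with offset = 0 and strides[0] = 1.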

Reviewed By: jsetoain

Differential Revision: https://reviews.llvm.org/D110221

commit 56bddf3b1c
parent b33a1cc05b
Author: Aart Bik
Date:   2021-09-21 22:56:00 -07:00

@@ -502,32 +502,11 @@ char *getTensorFilename(uint64_t id) {
 // with sparse tensors, which are only visible as opaque pointers externally.
 // These methods should be used exclusively by MLIR compiler-generated code.
 //
-// Because we cannot use C++ templates with C linkage, some macro magic is used
-// to generate implementations for all required type combinations that can be
-// called from MLIR compiler-generated code.
+// Some macro magic is used to generate implementations for all required type
+// combinations that can be called from MLIR compiler-generated code.
 //
 //===----------------------------------------------------------------------===//

-#define TEMPLATE(NAME, TYPE) \
-  struct NAME { \
-    const TYPE *base; \
-    const TYPE *data; \
-    uint64_t off; \
-    uint64_t sizes[1]; \
-    uint64_t strides[1]; \
-  }
-
-TEMPLATE(MemRef1DU64, uint64_t);
-TEMPLATE(MemRef1DU32, uint32_t);
-TEMPLATE(MemRef1DU16, uint16_t);
-TEMPLATE(MemRef1DU8, uint8_t);
-TEMPLATE(MemRef1DI64, int64_t);
-TEMPLATE(MemRef1DI32, int32_t);
-TEMPLATE(MemRef1DI16, int16_t);
-TEMPLATE(MemRef1DI8, int8_t);
-TEMPLATE(MemRef1DF64, double);
-TEMPLATE(MemRef1DF32, float);
-
 #define CASE(p, i, v, P, I, V) \
   if (ptrTp == (p) && indTp == (i) && valTp == (v)) { \
     SparseTensorCOO<V> *tensor = nullptr; \
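
To spell out the rationale dropped from the comment above: a function with C
linkage cannot itself be a template, but it can take an already-instantiated
template type as a parameter. The macros are therefore still needed to stamp
out one entry point per type combination, while the parameter type can now be
the shared template. A minimal sketch of the distinction (_mlir_ciface_bad and
_mlir_ciface_ok are hypothetical names for illustration):

// A function template cannot be given C linkage ...
// template <typename T>
// extern "C" void _mlir_ciface_bad(T *ref); // error: C linkage on a template
//
// ... but an extern "C" function may take an instantiated template type:
extern "C" void _mlir_ciface_ok(StridedMemRefType<double, 1> *ref);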
@@ -544,35 +523,37 @@ TEMPLATE(MemRef1DF32, float);
         perm); \
   }

-#define IMPL1(REF, NAME, TYPE, LIB) \
-  void _mlir_ciface_##NAME(REF *ref, void *tensor) { \
+#define IMPL1(NAME, TYPE, LIB) \
+  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor) { \
     std::vector<TYPE> *v; \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v); \
-    ref->base = ref->data = v->data(); \
-    ref->off = 0; \
+    ref->basePtr = ref->data = v->data(); \
+    ref->offset = 0; \
     ref->sizes[0] = v->size(); \
     ref->strides[0] = 1; \
   }

-#define IMPL2(REF, NAME, TYPE, LIB) \
-  void _mlir_ciface_##NAME(REF *ref, void *tensor, uint64_t d) { \
+#define IMPL2(NAME, TYPE, LIB) \
+  void _mlir_ciface_##NAME(StridedMemRefType<TYPE, 1> *ref, void *tensor, \
+                           uint64_t d) { \
     std::vector<TYPE> *v; \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d); \
-    ref->base = ref->data = v->data(); \
-    ref->off = 0; \
+    ref->basePtr = ref->data = v->data(); \
+    ref->offset = 0; \
     ref->sizes[0] = v->size(); \
     ref->strides[0] = 1; \
   }

 #define IMPL3(NAME, TYPE) \
-  void *_mlir_ciface_##NAME(void *tensor, TYPE value, MemRef1DU64 *iref, \
-                            MemRef1DU64 *pref) { \
+  void *_mlir_ciface_##NAME(void *tensor, TYPE value, \
+                            StridedMemRefType<uint64_t, 1> *iref, \
+                            StridedMemRefType<uint64_t, 1> *pref) { \
     if (!value) \
       return tensor; \
     assert(iref->strides[0] == 1 && pref->strides[0] == 1); \
     assert(iref->sizes[0] == pref->sizes[0]); \
-    const uint64_t *indx = iref->data + iref->off; \
-    const uint64_t *perm = pref->data + pref->off; \
+    const uint64_t *indx = iref->data + iref->offset; \
+    const uint64_t *perm = pref->data + pref->offset; \
     uint64_t isize = iref->sizes[0]; \
     std::vector<uint64_t> indices(isize); \
     for (uint64_t r = 0; r < isize; r++) \
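
For readability, here is a hand-expanded sketch of one instantiation of the
new IMPL1 (using IMPL1(sparseValuesF64, double, getValues) from further below):

// Hand-expanded sketch of IMPL1(sparseValuesF64, double, getValues).
void _mlir_ciface_sparseValuesF64(StridedMemRefType<double, 1> *ref,
                                  void *tensor) {
  std::vector<double> *v;
  static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);
  ref->basePtr = ref->data = v->data(); // expose the vector's storage in place
  ref->offset = 0;
  ref->sizes[0] = v->size(); // a dense, unit-stride view over all values
  ref->strides[0] = 1;
}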
@@ -599,16 +580,18 @@ enum PrimaryTypeEnum : uint64_t {
 /// 1 : ptr contains coordinate scheme to assign to new storage
 /// 2 : returns empty coordinate scheme to fill (call back 1 to setup)
 /// 3 : returns coordinate scheme from storage in ptr (call back 1 to convert)
-void *_mlir_ciface_newSparseTensor(MemRef1DU8 *aref, // NOLINT
-                                   MemRef1DU64 *sref, MemRef1DU64 *pref,
-                                   uint64_t ptrTp, uint64_t indTp,
-                                   uint64_t valTp, uint32_t action, void *ptr) {
+void *
+_mlir_ciface_newSparseTensor(StridedMemRefType<uint8_t, 1> *aref, // NOLINT
+                             StridedMemRefType<uint64_t, 1> *sref,
+                             StridedMemRefType<uint64_t, 1> *pref,
+                             uint64_t ptrTp, uint64_t indTp, uint64_t valTp,
+                             uint32_t action, void *ptr) {
   assert(aref->strides[0] == 1 && sref->strides[0] == 1 &&
          pref->strides[0] == 1);
   assert(aref->sizes[0] == sref->sizes[0] && sref->sizes[0] == pref->sizes[0]);
-  const uint8_t *sparsity = aref->data + aref->off;
-  const uint64_t *sizes = sref->data + sref->off;
-  const uint64_t *perm = pref->data + pref->off;
+  const uint8_t *sparsity = aref->data + aref->offset;
+  const uint64_t *sizes = sref->data + sref->offset;
+  const uint64_t *perm = pref->data + pref->offset;
   uint64_t size = aref->sizes[0];

   // Double matrices with all combinations of overhead storage.
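
The asserts restrict the inputs to unit-stride, equally sized 1-D memrefs,
after which the three arguments read as plain arrays. A hedged sketch of a
hand-written caller setting up such descriptors (the annotation values in
sparsity are illustrative, not the library's actual encoding):

// Illustrative only: wrap stack arrays in unit-stride 1-D descriptors the
// way MLIR compiler-generated code does before calling newSparseTensor.
uint8_t sparsity[2] = {0, 1}; // per-dimension annotations (illustrative)
uint64_t shape[2] = {4, 8};   // a 4x8 tensor
uint64_t perm[2] = {0, 1};    // identity dimension ordering
StridedMemRefType<uint8_t, 1> aref = {sparsity, sparsity, 0, {2}, {1}};
StridedMemRefType<uint64_t, 1> sref = {shape, shape, 0, {2}, {1}};
StridedMemRefType<uint64_t, 1> pref = {perm, perm, 0, {2}, {1}};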
@@ -668,22 +651,22 @@ void *_mlir_ciface_newSparseTensor(MemRef1DU8 *aref, // NOLINT
 }

 /// Methods that provide direct access to pointers, indices, and values.
-IMPL2(MemRef1DU64, sparsePointers, uint64_t, getPointers)
-IMPL2(MemRef1DU64, sparsePointers64, uint64_t, getPointers)
-IMPL2(MemRef1DU32, sparsePointers32, uint32_t, getPointers)
-IMPL2(MemRef1DU16, sparsePointers16, uint16_t, getPointers)
-IMPL2(MemRef1DU8, sparsePointers8, uint8_t, getPointers)
-IMPL2(MemRef1DU64, sparseIndices, uint64_t, getIndices)
-IMPL2(MemRef1DU64, sparseIndices64, uint64_t, getIndices)
-IMPL2(MemRef1DU32, sparseIndices32, uint32_t, getIndices)
-IMPL2(MemRef1DU16, sparseIndices16, uint16_t, getIndices)
-IMPL2(MemRef1DU8, sparseIndices8, uint8_t, getIndices)
-IMPL1(MemRef1DF64, sparseValuesF64, double, getValues)
-IMPL1(MemRef1DF32, sparseValuesF32, float, getValues)
-IMPL1(MemRef1DI64, sparseValuesI64, int64_t, getValues)
-IMPL1(MemRef1DI32, sparseValuesI32, int32_t, getValues)
-IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues)
-IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)
+IMPL2(sparsePointers, uint64_t, getPointers)
+IMPL2(sparsePointers64, uint64_t, getPointers)
+IMPL2(sparsePointers32, uint32_t, getPointers)
+IMPL2(sparsePointers16, uint16_t, getPointers)
+IMPL2(sparsePointers8, uint8_t, getPointers)
+IMPL2(sparseIndices, uint64_t, getIndices)
+IMPL2(sparseIndices64, uint64_t, getIndices)
+IMPL2(sparseIndices32, uint32_t, getIndices)
+IMPL2(sparseIndices16, uint16_t, getIndices)
+IMPL2(sparseIndices8, uint8_t, getIndices)
+IMPL1(sparseValuesF64, double, getValues)
+IMPL1(sparseValuesF32, float, getValues)
+IMPL1(sparseValuesI64, int64_t, getValues)
+IMPL1(sparseValuesI32, int32_t, getValues)
+IMPL1(sparseValuesI16, int16_t, getValues)
+IMPL1(sparseValuesI8, int8_t, getValues)

 /// Helper to add value to coordinate scheme, one per value type.
 IMPL3(addEltF64, double)
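
On the consumer side, each entry point fills in a caller-provided descriptor.
A hedged sketch of reading the values array back by hand (assuming tensor was
obtained from _mlir_ciface_newSparseTensor, and process is a hypothetical
consumer function):

// Illustrative only: real callers are generated by the MLIR sparse compiler.
StridedMemRefType<double, 1> vals;
_mlir_ciface_sparseValuesF64(&vals, tensor);
for (int64_t i = 0; i < vals.sizes[0]; i++)
  process(vals.data[vals.offset + i * vals.strides[0]]);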