mirror of
https://github.com/intel/llvm.git
synced 2026-01-16 05:32:28 +08:00
[mlir][sparse] cleanup merger test, add header (#70279)
This commit is contained in:
@@ -1,3 +1,11 @@
|
||||
//===- MergerTest.cpp - Tests for the sparsifier's merger -----------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
|
||||
#include "llvm/Support/Compiler.h"
|
||||
#include "gmock/gmock.h"
|
||||
@@ -73,56 +81,43 @@ namespace {
|
||||
/// Helper classes/functions for testing Merger.
|
||||
///
|
||||
|
||||
/// Simple recursive data structure used to match expressions in `Merger`.
|
||||
struct Pattern;
|
||||
/// Since the patterns we need are rather small and short-lived, we use
|
||||
/// `Pattern const&` for "pointers" to patterns, rather than using
|
||||
/// something more elaborate like `std::shared_ptr<Pattern> const&`.
|
||||
using PatternRef = const Pattern &;
|
||||
struct Pattern {
|
||||
/// Simple recursive data structure used to match expressions in `Merger`,
|
||||
/// which uses const references into the short-lived data strucutures.
|
||||
struct Match {
|
||||
struct Children {
|
||||
Children(PatternRef e0, PatternRef e1) : e0(e0), e1(e1) {}
|
||||
PatternRef e0;
|
||||
PatternRef e1;
|
||||
Children(const Match &e0, const Match &e1) : e0(e0), e1(e1) {}
|
||||
const Match &e0;
|
||||
const Match &e1;
|
||||
};
|
||||
|
||||
TensorExp::Kind kind;
|
||||
|
||||
union {
|
||||
/// Expressions representing tensors simply have a tensor number.
|
||||
TensorId tid;
|
||||
|
||||
/// Tensor operations point to their children.
|
||||
Children children;
|
||||
};
|
||||
|
||||
/// Constructors.
|
||||
/// Rather than using these, please use the readable builder
|
||||
/// functions below to make tests more readable.
|
||||
Pattern() : kind(TensorExp::Kind::kSynZero) {}
|
||||
Pattern(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
|
||||
Pattern(TensorExp::Kind kind, PatternRef e0, PatternRef e1)
|
||||
Match() : kind(TensorExp::Kind::kSynZero) {}
|
||||
Match(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
|
||||
Match(TensorExp::Kind kind, const Match &e0, const Match &e1)
|
||||
: kind(kind), children(e0, e1) {
|
||||
assert(kind >= TensorExp::Kind::kMulF);
|
||||
}
|
||||
|
||||
TensorExp::Kind kind;
|
||||
union {
|
||||
TensorId tid;
|
||||
Children children;
|
||||
};
|
||||
};
|
||||
|
||||
///
|
||||
/// Readable Pattern builder functions.
|
||||
/// Readable Match builder functions.
|
||||
/// These should be preferred over the actual constructors.
|
||||
///
|
||||
|
||||
static Pattern tensorPattern(TensorId tid) { return Pattern(tid); }
|
||||
static Pattern synZeroPattern() { return Pattern(); }
|
||||
static Match tensorMatch(TensorId tid) { return Match(tid); }
|
||||
static Match synZeroMatch() { return Match(); }
|
||||
|
||||
#define IMPL_BINOP_PATTERN(OP, KIND) \
|
||||
LLVM_ATTRIBUTE_UNUSED static Pattern OP##Pattern(PatternRef e0, \
|
||||
PatternRef e1) { \
|
||||
return Pattern(KIND, e0, e1); \
|
||||
LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0, \
|
||||
const Match &e1) { \
|
||||
return Match(KIND, e0, e1); \
|
||||
}
|
||||
|
||||
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
|
||||
|
||||
#undef IMPL_BINOP_PATTERN
|
||||
|
||||
class MergerTestBase : public ::testing::Test {
|
||||
@@ -150,9 +145,7 @@ protected:
|
||||
LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) { \
|
||||
return merger.addExp(KIND, e0, e1); \
|
||||
}
|
||||
|
||||
FOREVERY_BINOP(IMPL_BINOP_EXPR)
|
||||
|
||||
#undef IMPL_BINOP_EXPR
|
||||
|
||||
///
|
||||
@@ -168,7 +161,7 @@ protected:
|
||||
/// ordering within groups. If `simple` is true, then compare the
|
||||
/// `lat.simple` field instead to test the result after optimization.
|
||||
bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
|
||||
PatternRef pattern, const BitVector &bits,
|
||||
const Match &pattern, const BitVector &bits,
|
||||
bool simple) {
|
||||
for (unsigned k = lo, hi = lo + n; k < hi; ++k) {
|
||||
if (compareExpression(merger.lat(merger.set(s)[k]).exp, pattern) &&
|
||||
@@ -180,13 +173,13 @@ protected:
|
||||
|
||||
/// Wrapper over latPointWithinRange for readability of tests.
|
||||
void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
|
||||
PatternRef pattern, const BitVector &bits,
|
||||
const Match &pattern, const BitVector &bits,
|
||||
bool simple = false) {
|
||||
EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
|
||||
}
|
||||
|
||||
/// Wrapper over expectLatPointWithinRange for a single lat point.
|
||||
void expectLatPoint(LatSetId s, unsigned lo, PatternRef pattern,
|
||||
void expectLatPoint(LatSetId s, unsigned lo, const Match &pattern,
|
||||
const BitVector &bits, bool simple = false) {
|
||||
EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
|
||||
}
|
||||
@@ -216,7 +209,7 @@ protected:
|
||||
/// Compares expressions for equality. Equality is defined recursively as:
|
||||
/// - Operations are equal if they have the same kind and children.
|
||||
/// - Leaf tensors are equal if they refer to the same tensor.
|
||||
bool compareExpression(ExprId e, PatternRef pattern) {
|
||||
bool compareExpression(ExprId e, const Match &pattern) {
|
||||
const auto &tensorExp = merger.exp(e);
|
||||
if (tensorExp.kind != pattern.kind)
|
||||
return false;
|
||||
@@ -424,21 +417,19 @@ protected:
|
||||
const auto t0 = tid(0); \
|
||||
const auto t1 = tid(1); \
|
||||
const auto t2 = tid(2); \
|
||||
PatternRef p0 = tensorPattern(t0); \
|
||||
PatternRef p1 = tensorPattern(t1); \
|
||||
PatternRef p2 = tensorPattern(t2); \
|
||||
const Match &p0 = tensorMatch(t0); \
|
||||
const Match &p1 = tensorMatch(t1); \
|
||||
const Match &p2 = tensorMatch(t2); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t1}}), true); \
|
||||
}
|
||||
|
||||
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
|
||||
|
||||
#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
|
||||
|
||||
/// Vector multiplication (conjunction) of 2 vectors, i.e.;
|
||||
@@ -461,21 +452,19 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
|
||||
const auto t1 = tid(1); \
|
||||
const auto t2 = tid(2); \
|
||||
const auto t3 = tid(3); \
|
||||
PatternRef p0 = tensorPattern(t0); \
|
||||
PatternRef p1 = tensorPattern(t1); \
|
||||
PatternRef p2 = tensorPattern(t2); \
|
||||
const Match &p0 = tensorMatch(t0); \
|
||||
const Match &p1 = tensorMatch(t1); \
|
||||
const Match &p2 = tensorMatch(t2); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}})); \
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t3}}), true); \
|
||||
}
|
||||
|
||||
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
|
||||
|
||||
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
|
||||
|
||||
/// Vector addition (disjunction) of 2 vectors. i.e.;
|
||||
@@ -499,26 +488,24 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
|
||||
const auto l0 = lid(0); \
|
||||
const auto t0 = tid(0); \
|
||||
const auto t1 = tid(1); \
|
||||
PatternRef p0 = tensorPattern(t0); \
|
||||
PatternRef p1 = tensorPattern(t1); \
|
||||
const Match &p0 = tensorMatch(t0); \
|
||||
const Match &p1 = tensorMatch(t1); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, 0, OP##Pattern(p0, p1), \
|
||||
expectLatPoint(s, 0, OP##Match(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
|
||||
expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, 0, OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}}), true); \
|
||||
expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
|
||||
true); \
|
||||
expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true); \
|
||||
expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true); \
|
||||
}
|
||||
|
||||
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_DISJ
|
||||
|
||||
/// Vector multiplication (conjunction) of 2 vectors, i.e.;
|
||||
@@ -533,22 +520,20 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
|
||||
const auto l0 = lid(0); \
|
||||
const auto t0 = tid(0); \
|
||||
const auto t1 = tid(1); \
|
||||
PatternRef p0 = tensorPattern(t0); \
|
||||
PatternRef p1 = tensorPattern(t1); \
|
||||
const Match &p0 = tensorMatch(t0); \
|
||||
const Match &p1 = tensorMatch(t1); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, OP##Pattern(p0, p1), \
|
||||
expectLatPoint(s, 0, OP##Match(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}}), true); \
|
||||
expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
|
||||
true); \
|
||||
}
|
||||
|
||||
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_CONJ
|
||||
|
||||
/// Vector multiplication (conjunction) then addition (disjunction), i.e.;
|
||||
@@ -567,29 +552,27 @@ FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
|
||||
const auto t0 = tid(0); \
|
||||
const auto t1 = tid(1); \
|
||||
const auto t2 = tid(2); \
|
||||
PatternRef p0 = tensorPattern(t0); \
|
||||
PatternRef p1 = tensorPattern(t1); \
|
||||
PatternRef p2 = tensorPattern(t2); \
|
||||
const Match &p0 = tensorMatch(t0); \
|
||||
const Match &p1 = tensorMatch(t1); \
|
||||
const Match &p2 = tensorMatch(t2); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1), \
|
||||
expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, 0, DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 2, CONJ##Pattern(p0, p1), \
|
||||
expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}})); \
|
||||
}
|
||||
|
||||
FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_CONJ_DISJ
|
||||
|
||||
/// Vector addition (disjunction) then addition (disjunction), i.e.;
|
||||
@@ -612,19 +595,19 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
|
||||
const auto t0 = tid(0); \
|
||||
const auto t1 = tid(1); \
|
||||
const auto t2 = tid(2); \
|
||||
PatternRef p0 = tensorPattern(t0); \
|
||||
PatternRef p1 = tensorPattern(t1); \
|
||||
PatternRef p2 = tensorPattern(t2); \
|
||||
const Match &p0 = tensorMatch(t0); \
|
||||
const Match &p1 = tensorMatch(t1); \
|
||||
const Match &p2 = tensorMatch(t2); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 7); \
|
||||
expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2), \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2), \
|
||||
loopsToBits({{l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2), \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1), \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
|
||||
@@ -632,21 +615,19 @@ FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 7); \
|
||||
expectLatPoint(s, 0, DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p1, p2), \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2), \
|
||||
loopsToBits({{l0, t1}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ2##Pattern(p0, p2), \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ1##Pattern(p0, p1), \
|
||||
expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}})); \
|
||||
expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}})); \
|
||||
}
|
||||
|
||||
FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_DISJ_DISJ
|
||||
|
||||
/// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
|
||||
@@ -663,21 +644,19 @@ FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
|
||||
const auto t0 = tid(0); \
|
||||
const auto t1 = tid(1); \
|
||||
const auto t2 = tid(2); \
|
||||
PatternRef p0 = tensorPattern(t0); \
|
||||
PatternRef p1 = tensorPattern(t1); \
|
||||
PatternRef p2 = tensorPattern(t2); \
|
||||
const Match &p0 = tensorMatch(t0); \
|
||||
const Match &p1 = tensorMatch(t1); \
|
||||
const Match &p2 = tensorMatch(t2); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
|
||||
expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \
|
||||
}
|
||||
|
||||
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_CONJ_CONJ
|
||||
|
||||
/// Vector addition (disjunction) of 2 vectors, i.e.;
|
||||
@@ -702,25 +681,23 @@ FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
|
||||
const auto l0 = lid(0); \
|
||||
const auto t0 = tid(0); \
|
||||
const auto t1 = tid(1); \
|
||||
PatternRef p0 = tensorPattern(t0); \
|
||||
PatternRef p1 = tensorPattern(t1); \
|
||||
const Match &p0 = tensorMatch(t0); \
|
||||
const Match &p1 = tensorMatch(t1); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 3); \
|
||||
expectLatPoint(s, 0, OP##Pattern(p0, p1), \
|
||||
expectLatPoint(s, 0, OP##Match(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
|
||||
expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 2); \
|
||||
expectLatPoint(s, 0, OP##Pattern(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}}), true); \
|
||||
expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
|
||||
true); \
|
||||
expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true); \
|
||||
}
|
||||
|
||||
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
|
||||
|
||||
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
|
||||
|
||||
/// Vector multiplication (conjunction) of 2 vectors, i.e.:
|
||||
@@ -740,20 +717,20 @@ FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
|
||||
const auto l0 = lid(0); \
|
||||
const auto t0 = tid(0); \
|
||||
const auto t1 = tid(1); \
|
||||
PatternRef p0 = tensorPattern(t0); \
|
||||
PatternRef p1 = tensorPattern(t1); \
|
||||
const Match &p0 = tensorMatch(t0); \
|
||||
const Match &p1 = tensorMatch(t1); \
|
||||
auto s = merger.buildLattices(e, l0); \
|
||||
\
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, OP##Pattern(p0, p1), \
|
||||
expectLatPoint(s, 0, OP##Match(p0, p1), \
|
||||
loopsToBits({{l0, t0}, {l0, t1}})); \
|
||||
\
|
||||
s = merger.optimizeSet(s); \
|
||||
expectNumLatPoints(s, 1); \
|
||||
expectLatPoint(s, 0, OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), true); \
|
||||
expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}}), true); \
|
||||
}
|
||||
|
||||
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
|
||||
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
|
||||
|
||||
/// Vector element-wise comparison (disjunction) of 2 vectors. i.e.;
|
||||
/// a(i) = b(i) + c(i)
|
||||
@@ -775,20 +752,20 @@ TEST_F(MergerTest3T1L, vector_cmp) {
|
||||
const auto l0 = lid(0);
|
||||
const auto t0 = tid(0);
|
||||
const auto t1 = tid(1);
|
||||
PatternRef zero = synZeroPattern();
|
||||
PatternRef p0 = tensorPattern(t0);
|
||||
PatternRef p1 = tensorPattern(t1);
|
||||
const Match &zero = synZeroMatch();
|
||||
const Match &p0 = tensorMatch(t0);
|
||||
const Match &p1 = tensorMatch(t1);
|
||||
auto s = merger.buildLattices(e, l0);
|
||||
expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiPattern(p0, zero),
|
||||
expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
|
||||
loopsToBits({{l0, t0}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
|
||||
loopsToBits({{l0, t1}}));
|
||||
s = merger.optimizeSet(s);
|
||||
expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiPattern(p0, zero),
|
||||
expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
|
||||
loopsToBits({{l0, t0}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
|
||||
loopsToBits({{l0, t1}}));
|
||||
}
|
||||
|
||||
@@ -813,19 +790,17 @@ TEST_F(MergerTest3T1LD, vector_cmp) {
|
||||
const auto l0 = lid(0);
|
||||
const auto t0 = tid(0);
|
||||
const auto t1 = tid(1);
|
||||
PatternRef zero = synZeroPattern();
|
||||
PatternRef p0 = tensorPattern(t0);
|
||||
PatternRef p1 = tensorPattern(t1);
|
||||
const Match &zero = synZeroMatch();
|
||||
const Match &p0 = tensorMatch(t0);
|
||||
const Match &p1 = tensorMatch(t1);
|
||||
auto s = merger.buildLattices(e, l0);
|
||||
expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiPattern(p0, zero),
|
||||
expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
|
||||
loopsToBits({{l0, t0}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
|
||||
loopsToBits({{l0, t1}}));
|
||||
s = merger.optimizeSet(s);
|
||||
expectLatPoint(s, 0, cmpiPattern(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiPattern(zero, p1),
|
||||
expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
|
||||
expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
|
||||
loopsToBits({{l0, t1}}));
|
||||
}
|
||||
|
||||
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
|
||||
|
||||
Reference in New Issue
Block a user