diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index dda3b3827c7a..026ee035fcf5 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -1302,20 +1302,20 @@ inline BinaryOpc_match m_Not(const ValTy &V) { template struct ReassociatableOpc_match { unsigned Opcode; std::tuple Patterns; + constexpr static size_t NumPatterns = + std::tuple_size_v>; ReassociatableOpc_match(unsigned Opcode, const PatternTs &...Patterns) : Opcode(Opcode), Patterns(Patterns...) {} template bool match(const MatchContext &Ctx, SDValue N) { - constexpr size_t NumPatterns = std::tuple_size_v>; - - SmallVector Leaves; - collectLeaves(N, Leaves); - if (Leaves.size() != NumPatterns) + std::array Leaves; + size_t LeavesIdx = 0; + if (!(collectLeaves(N, Leaves, LeavesIdx) && (LeavesIdx == NumPatterns))) return false; - SmallBitVector Used(NumPatterns); + Bitset Used; return std::apply( [&](auto &...P) -> bool { return reassociatableMatchHelper(Ctx, Leaves, Used, P...); @@ -1323,36 +1323,41 @@ template struct ReassociatableOpc_match { Patterns); } - void collectLeaves(SDValue V, SmallVector &Leaves) { + bool collectLeaves(SDValue V, std::array &Leaves, + std::size_t &LeafIdx) { if (V->getOpcode() == Opcode) { for (size_t I = 0, N = V->getNumOperands(); I < N; I++) - collectLeaves(V->getOperand(I), Leaves); + if ((LeafIdx == NumPatterns) || + !collectLeaves(V->getOperand(I), Leaves, LeafIdx)) + return false; } else { - Leaves.emplace_back(V); + Leaves[LeafIdx] = V; + LeafIdx++; } + return true; } // Searchs for a matching leaf for every sub-pattern. template [[nodiscard]] inline bool reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef Leaves, - SmallBitVector &Used, PatternHd &HeadPattern, + Bitset &Used, PatternHd &HeadPattern, PatternTl &...TailPatterns) { for (size_t Match = 0, N = Used.size(); Match < N; Match++) { if (Used[Match] || !(sd_context_match(Leaves[Match], Ctx, HeadPattern))) continue; - Used[Match] = true; + Used.set(Match); if (reassociatableMatchHelper(Ctx, Leaves, Used, TailPatterns...)) return true; - Used[Match] = false; + Used.reset(Match); } return false; } template - [[nodiscard]] inline bool reassociatableMatchHelper(const MatchContext &Ctx, - ArrayRef Leaves, - SmallBitVector &Used) { + [[nodiscard]] inline bool + reassociatableMatchHelper(const MatchContext &Ctx, ArrayRef Leaves, + Bitset &Used) { return true; } }; diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index 4fcd3fcb8c5c..1afc034dd7b9 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -803,6 +803,8 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) { SDValue ADD = DAG->getNode(ISD::ADD, DL, Int32VT, ADD01, ADD23); EXPECT_FALSE(sd_match(ADD01, m_ReassociatableAdd(m_Value()))); + EXPECT_FALSE( + sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value(), m_Value()))); EXPECT_TRUE(sd_match(ADD01, m_ReassociatableAdd(m_Value(), m_Value()))); EXPECT_TRUE(sd_match(ADD23, m_ReassociatableAdd(m_Value(), m_Value()))); EXPECT_TRUE(sd_match(