mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
This is a follow-up to PR #162699. Currently, in the function where we define rewrite patterns, the `op` we receive is of type `ir.Operation` rather than a specific `OpView` type (such as `arith.AddIOp`). This means we can’t conveniently access certain parts of the operation — for example, we need to use `op.operands[0]` instead of `op.lhs`. The following example code illustrates this situation:

```python
def to_muli(op, rewriter):
    # op is typed ir.Operation instead of arith.AddIOp
    pass

patterns.add(arith.AddIOp, to_muli)
```

In this PR, we convert the operation to its corresponding `OpView` subclass before invoking the rewrite pattern callback, making it much easier to write patterns.

---------

Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
71 lines
1.9 KiB
Python
71 lines
1.9 KiB
Python
# RUN: %PYTHON %s 2>&1 | FileCheck %s
|
|
|
|
from mlir.ir import *
|
|
from mlir.passmanager import *
|
|
from mlir.dialects.builtin import ModuleOp
|
|
from mlir.dialects import arith
|
|
from mlir.rewrite import *
|
|
|
|
|
|
def run(f):
    """Print a FileCheck ``TEST:`` label for *f*, then call *f*.

    Intended for use as a decorator: the decorated test body runs as soon as
    it is defined, and because this function has no ``return`` statement the
    decorated name is rebound to ``None``.
    """
    header = f"\nTEST: {f.__name__}"
    print(header)
    f()
|
|
|
|
|
|
# CHECK-LABEL: TEST: testRewritePattern
@run
def testRewritePattern():
    """Apply two Python-defined rewrite patterns greedily and print the IR.

    The callbacks rely on the matched op arriving as its specific ``OpView``
    subclass (``to_muli`` asserts it is an ``arith.AddIOp``), so they use
    named accessors such as ``op.lhs``, ``op.rhs`` and ``op.value`` rather
    than positional operands.
    """

    def to_muli(op, rewriter):
        # addi(a, b) -> muli(a, b). The replacement is created at the
        # rewriter's insertion point and reuses the original op's location.
        with rewriter.ip:
            assert isinstance(op, arith.AddIOp)
            product = arith.muli(op.lhs, op.rhs, loc=op.location)
        rewriter.replace_op(op, product.owner)

    def constant_1_to_2(op, rewriter):
        # Only constants whose value equals 1 are rewritten (to 2). Returning
        # True is how this callback reports that the pattern did not match.
        if op.value.value != 1:
            return True  # failed to match
        with rewriter.ip:
            replacement = arith.constant(op.type, 2, loc=op.location)
        rewriter.replace_op(op, [replacement])

    # Input 1: the single addi should become a muli.
    add_ir = r"""
    module {
      func.func @add(%a: i64, %b: i64) -> i64 {
        %sum = arith.addi %a, %b : i64
        return %sum : i64
      }
    }
    """
    # CHECK: %0 = arith.muli %arg0, %arg1 : i64
    # CHECK: return %0 : i64

    # Input 2: only the constant equal to 1 should be rewritten.
    const_ir = r"""
    module {
      func.func @const() -> (i64, i64) {
        %0 = arith.constant 1 : i64
        %1 = arith.constant 3 : i64
        return %0, %1 : i64, i64
      }
    }
    """
    # CHECK: %c2_i64 = arith.constant 2 : i64
    # CHECK: %c3_i64 = arith.constant 3 : i64
    # CHECK: return %c2_i64, %c3_i64 : i64, i64

    with Context():
        patterns = RewritePatternSet()
        patterns.add(arith.AddIOp, to_muli)
        patterns.add(arith.ConstantOp, constant_1_to_2)
        frozen = patterns.freeze()

        # Parse, rewrite and print each input in turn; the printed modules
        # are matched by the CHECK lines above, in the same order.
        for source in (add_ir, const_ir):
            module = ModuleOp.parse(source)
            apply_patterns_and_fold_greedily(module, frozen)
            print(module)
|