mirror of
https://github.com/intel/llvm.git
synced 2026-01-13 19:08:21 +08:00
* Removes index-based insertion. All insertion now happens through the insertion point.
* Introduces thread-local context managers for implicit creation relative to an insertion point.
* Introduces (but does not yet use) binding the Context to the thread-local context stack. The intent is to refactor all methods to take the context optionally and have them use the default if available.
* Adds C APIs for mlirOperationGetParentOperation(), mlirOperationGetBlock() and mlirBlockGetTerminator().
* Removes an assert in PyOperation creation that was incorrectly constraining. There is already a TODO to rework the keepAlive field that it was guarding, and without the assert it is no worse than the current state.

Differential Revision: https://reviews.llvm.org/D90368
153 lines
4.0 KiB
Python
153 lines
4.0 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
import gc
|
|
import io
|
|
import itertools
|
|
from mlir.ir import *
|
|
|
|
def run(f):
    """Run test function *f* under its FileCheck label and check for leaks.

    Prints the ``TEST: <name>`` line that this file's FileCheck label
    directives anchor on, invokes *f*, then forces a garbage collection and
    verifies that no ``Context`` objects remain alive.

    Raises:
        AssertionError: if any ``Context`` is still alive after *f* returns.
    """
    print("\nTEST:", f.__name__)
    f()
    # Release anything kept alive only by unreachable reference cycles before
    # counting, so only genuinely leaked contexts are reported.
    gc.collect()
    # Explicit raise rather than a bare `assert`: assert statements are
    # stripped when Python runs with -O, which would silently skip this check.
    live_count = Context._get_live_count()
    if live_count != 0:
        raise AssertionError(
            "{} leaked {} live Context object(s)".format(f.__name__, live_count))


# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
    """An InsertionPoint built from a bare block appends after its last op."""
    context = Context()
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    module = context.parse_module(r"""
      func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """)
    func_op = module.body.operations[0]
    block = func_op.regions[0].blocks[0]
    insert_point = InsertionPoint(block)
    insert_point.insert(context.create_operation("custom.op2", unknown_loc))
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    module.operation.print()


run(test_insert_at_block_end)


# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
    """An InsertionPoint built from an operation inserts right before that op."""
    context = Context()
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    module = context.parse_module(r"""
      func @foo() -> () {
        "custom.op1"() : () -> ()
        "custom.op2"() : () -> ()
      }
    """)
    block = module.body.operations[0].regions[0].blocks[0]
    second_op = block.operations[1]
    insert_point = InsertionPoint(second_op)
    new_op = context.create_operation("custom.op3", unknown_loc)
    insert_point.insert(new_op)
    # CHECK: "custom.op1"
    # CHECK: "custom.op3"
    # CHECK: "custom.op2"
    module.operation.print()


run(test_insert_before_operation)


# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
    """InsertionPoint.at_block_begin places new ops ahead of the current first op."""
    context = Context()
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    module = context.parse_module(r"""
      func @foo() -> () {
        "custom.op2"() : () -> ()
      }
    """)
    func_body = module.body.operations[0].regions[0]
    block = func_body.blocks[0]
    insert_point = InsertionPoint.at_block_begin(block)
    insert_point.insert(context.create_operation("custom.op1", unknown_loc))
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    module.operation.print()


run(test_insert_at_block_begin)


# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
    """Placeholder for inserting at the beginning of an empty block.

    Intentionally does nothing for now; ``run`` still prints the test's
    ``TEST:`` line, so the FileCheck label for this test keeps matching.
    """
    # TODO: Write this test case when we can create such a situation.


run(test_insert_at_block_begin_empty)


# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
    """InsertionPoint.at_block_terminator inserts just before the terminator."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    loc = ctx.get_unknown_location()
    module = ctx.parse_module(r"""
      func @foo() -> () {
        "custom.op1"() : () -> ()
        return
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    ip = InsertionPoint.at_block_terminator(entry_block)
    ip.insert(ctx.create_operation("custom.op2", loc))
    # Matching the terminator last pins custom.op2 *before* `return`; without
    # it, an insertion at the very end of the block would also pass.
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    # CHECK: return
    module.operation.print()


run(test_insert_at_terminator)


# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
    """at_block_terminator on a block without a terminator raises ValueError."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = ctx.parse_module(r"""
      func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """)
    entry_block = module.body.operations[0].regions[0].blocks[0]
    try:
        # Result intentionally discarded: only the raised error matters here.
        InsertionPoint.at_block_terminator(entry_block)
    except ValueError as e:
        # CHECK: Block has no terminator
        print(e)
    else:
        # Explicit raise instead of `assert False`: assert statements are
        # stripped under `python -O`, which would let this failure pass.
        raise AssertionError("Expected exception")


run(test_insert_at_block_terminator_missing)


# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
    """An InsertionPoint used as a context manager decides where new ops go.

    A nested ``with`` overrides the active insertion point, and leaving the
    inner block falls back to the outer one.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    module = context.parse_module(r"""
      func @foo() -> () {
        "custom.op1"() : () -> ()
      }
    """)
    block = module.body.operations[0].regions[0].blocks[0]
    outer_ip = InsertionPoint(block)
    with outer_ip:
        # Lands at the end of the block, after custom.op1.
        context.create_operation("custom.op2", unknown_loc)
        inner_ip = InsertionPoint.at_block_begin(block)
        with inner_ip:
            # Both land ahead of custom.op1, in creation order.
            for op_name in ("custom.opa", "custom.opb"):
                context.create_operation(op_name, unknown_loc)
        # Back under the outer insertion point: appended after custom.op2.
        context.create_operation("custom.op3", unknown_loc)
    # CHECK: "custom.opa"
    # CHECK: "custom.opb"
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    # CHECK: "custom.op3"
    module.operation.print()


run(test_insertion_point_context)