Files
llvm/mlir/test/Bindings/Python/insertion_point.py
Stella Laurenzo c645ea5e29 Add InsertionPoint and context managers to the Python API.
* Removes index-based insertion. All insertion now happens through the insertion point.
* Introduces thread-local context managers for implicit creation relative to an insertion point.
* Introduces (but does not yet use) binding the Context to the thread-local context stack. The intent is to refactor all methods to take a context optionally and have them use the default one if it is available.
* Adds C APIs for mlirOperationGetParentOperation(), mlirOperationGetBlock() and mlirBlockGetTerminator().
* Removes an assert in PyOperation creation that was incorrectly constraining. There is already a TODO to rework the keepAlive field that it was guarding and without the assert, it is no worse than the current state.

Differential Revision: https://reviews.llvm.org/D90368
2020-10-29 17:50:13 -07:00

153 lines
4.0 KiB
Python

# RUN: %PYTHON %s | FileCheck %s
import gc
import io
import itertools
from mlir.ir import *
def run(f):
    """Execute one test function and assert that no contexts remain live.

    A ``TEST: <name>`` banner (preceded by a newline) is printed before the
    call; the label directives in this file match against that banner.
    After ``f`` returns, a garbage collection is forced and
    ``Context._get_live_count()`` is asserted to be zero.

    Args:
        f: Zero-argument callable; its ``__name__`` appears in the banner.
    """
    print(f"\nTEST: {f.__name__}")
    f()
    # Collect first so that objects released by the test are not counted.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
# CHECK-LABEL: TEST: test_insert_at_block_end
def test_insert_at_block_end():
    """Insert through an InsertionPoint constructed directly from a block.

    The printed module is expected to list the new ``custom.op2`` after the
    pre-existing ``custom.op1``.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    parsed = context.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
}
""")
    func_op = parsed.body.operations[0]
    entry_block = func_op.regions[0].blocks[0]
    insert_pt = InsertionPoint(entry_block)
    new_op = context.create_operation("custom.op2", unknown_loc)
    insert_pt.insert(new_op)
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    parsed.operation.print()


run(test_insert_at_block_end)
# CHECK-LABEL: TEST: test_insert_before_operation
def test_insert_before_operation():
    """Insert through an InsertionPoint constructed from an operation.

    The insertion point is built from the block's second operation
    (``custom.op2``); the printed module is expected to show the new
    ``custom.op3`` between ``custom.op1`` and ``custom.op2``.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    parsed = context.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
"custom.op2"() : () -> ()
}
""")
    func_op = parsed.body.operations[0]
    entry_block = func_op.regions[0].blocks[0]
    anchor_op = entry_block.operations[1]
    insert_pt = InsertionPoint(anchor_op)
    new_op = context.create_operation("custom.op3", unknown_loc)
    insert_pt.insert(new_op)
    # CHECK: "custom.op1"
    # CHECK: "custom.op3"
    # CHECK: "custom.op2"
    parsed.operation.print()


run(test_insert_before_operation)
# CHECK-LABEL: TEST: test_insert_at_block_begin
def test_insert_at_block_begin():
    """Insert through ``InsertionPoint.at_block_begin`` on a non-empty block.

    The printed module is expected to list the new ``custom.op1`` ahead of
    the pre-existing ``custom.op2``.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    parsed = context.parse_module(r"""
func @foo() -> () {
"custom.op2"() : () -> ()
}
""")
    func_op = parsed.body.operations[0]
    entry_block = func_op.regions[0].blocks[0]
    insert_pt = InsertionPoint.at_block_begin(entry_block)
    new_op = context.create_operation("custom.op1", unknown_loc)
    insert_pt.insert(new_op)
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    parsed.operation.print()


run(test_insert_at_block_begin)
# CHECK-LABEL: TEST: test_insert_at_block_begin_empty
def test_insert_at_block_begin_empty():
    """Placeholder: currently does nothing beyond emitting its banner."""
    # TODO: Write this test case when we can create such a situation.
    return


run(test_insert_at_block_begin_empty)
# CHECK-LABEL: TEST: test_insert_at_terminator
def test_insert_at_terminator():
    """Insert through ``InsertionPoint.at_block_terminator``.

    The block ends with ``return``. The directives below only verify that
    the new ``custom.op2`` is printed after ``custom.op1``; they do not pin
    its position relative to ``return``.
    """
    context = Context()
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    parsed = context.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
return
}
""")
    func_op = parsed.body.operations[0]
    entry_block = func_op.regions[0].blocks[0]
    insert_pt = InsertionPoint.at_block_terminator(entry_block)
    new_op = context.create_operation("custom.op2", unknown_loc)
    insert_pt.insert(new_op)
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    parsed.operation.print()


run(test_insert_at_terminator)
# CHECK-LABEL: TEST: test_insert_at_block_terminator_missing
def test_insert_at_block_terminator_missing():
    """Check that ``at_block_terminator`` rejects a block with no terminator.

    The parsed function body holds a single ``custom.op1`` and no
    terminator, so ``InsertionPoint.at_block_terminator`` is expected to
    raise ``ValueError``. The exception text is printed so the directive
    below can match it.

    Raises:
        AssertionError: If no ``ValueError`` is raised.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = ctx.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
}
""")
    entry_block = module.body.operations[0].regions[0].blocks[0]
    try:
        # Only the raised exception matters; the result is discarded.
        InsertionPoint.at_block_terminator(entry_block)
    except ValueError as e:
        # CHECK: Block has no terminator
        print(e)
    else:
        # Raised explicitly rather than via `assert False`, which is
        # stripped when Python runs with -O.
        raise AssertionError("Expected exception")


run(test_insert_at_block_terminator_missing)
# CHECK-LABEL: TEST: test_insertion_point_context
def test_insertion_point_context():
    """Create operations implicitly under nested InsertionPoint contexts.

    No explicit ``insert`` call is made: operations are created while an
    InsertionPoint is active as a context manager. The expected printed
    order is opa, opb (created under the inner block-begin point), then the
    pre-existing op1, then op2 and op3 (created under the outer point,
    before and after the inner ``with`` respectively).
    """
    context = Context()
    context.allow_unregistered_dialects = True
    unknown_loc = context.get_unknown_location()
    parsed = context.parse_module(r"""
func @foo() -> () {
"custom.op1"() : () -> ()
}
""")
    func_op = parsed.body.operations[0]
    entry_block = func_op.regions[0].blocks[0]
    with InsertionPoint(entry_block):
        context.create_operation("custom.op2", unknown_loc)
        with InsertionPoint.at_block_begin(entry_block):
            for op_name in ("custom.opa", "custom.opb"):
                context.create_operation(op_name, unknown_loc)
        # Back under the outer insertion point.
        context.create_operation("custom.op3", unknown_loc)
    # CHECK: "custom.opa"
    # CHECK: "custom.opb"
    # CHECK: "custom.op1"
    # CHECK: "custom.op2"
    # CHECK: "custom.op3"
    parsed.operation.print()


run(test_insertion_point_context)