[mlir][EmitC] Do not inline expressions used by ops with the CExpression trait (#93691)

Currently an expression is inlined without emitting enclosing
parentheses regardless of the context of the user. This could led to
wrong evaluation order depending on the precedence of both expressions.
If the inlining is intended, the user operation should be merged into
the expression op.

Fixes #93470.
This commit is contained in:
Simon Camphausen
2024-06-04 13:14:08 +02:00
committed by GitHub
parent fc5254c8ac
commit a934ddcf7e
2 changed files with 83 additions and 3 deletions

View File

@@ -301,9 +301,9 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
if (isa<emitc::SubscriptOp>(user))
return false;
// Do not inline expressions used by other expressions, as any desired
// expression folding was taken care of by transformations.
return !user->getParentOfType<ExpressionOp>();
// Do not inline expressions used by ops with the CExpression trait. If this
// was intended, the user could have been merged into the expression op.
return !user->hasTrait<OpTrait::emitc::CExpression>();
}
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,

View File

@@ -100,6 +100,86 @@ func.func @parentheses_for_same_precedence(%arg0: i32, %arg1: i32, %arg2: i32) -
return %e : i32
}
// CPP-DEFAULT: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
// CPP-DEFAULT-NEXT: int32_t [[VAL_4:v[0-9]+]] = 0;
// CPP-DEFAULT-NEXT: int32_t [[EXP_0:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: int32_t [[EXP_1:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: int32_t [[EXP_2:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: int32_t [[EXP_3:v[0-9]+]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: bool [[CAST:v[0-9]+]] = (bool) [[EXP_0]];
// CPP-DEFAULT-NEXT: int32_t [[ADD:v[0-9]+]] = [[EXP_1]] + [[VAL_4]];
// CPP-DEFAULT-NEXT: int32_t [[CALL:v[0-9]+]] = bar([[EXP_2]], [[VAL_4]]);
// CPP-DEFAULT-NEXT: int32_t [[COND:v[0-9]+]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]];
// CPP-DEFAULT-NEXT: int32_t [[VAR:v[0-9]+]];
// CPP-DEFAULT-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DEFAULT-NEXT: }
// CPP-DECLTOP: int32_t user_with_expression_trait(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]]) {
// CPP-DECLTOP-NEXT: int32_t [[VAL_4:v[0-9]+]];
// CPP-DECLTOP-NEXT: int32_t [[EXP_0:v[0-9]+]];
// CPP-DECLTOP-NEXT: int32_t [[EXP_1:v[0-9]+]];
// CPP-DECLTOP-NEXT: int32_t [[EXP_2:v[0-9]+]];
// CPP-DECLTOP-NEXT: int32_t [[EXP_3:v[0-9]+]];
// CPP-DECLTOP-NEXT: bool [[CAST:v[0-9]+]];
// CPP-DECLTOP-NEXT: int32_t [[ADD:v[0-9]+]];
// CPP-DECLTOP-NEXT: int32_t [[CALL:v[0-9]+]];
// CPP-DECLTOP-NEXT: int32_t [[COND:v[0-9]+]];
// CPP-DECLTOP-NEXT: int32_t [[VAR:v[0-9]+]];
// CPP-DECLTOP-NEXT: [[VAL_4]] = 0;
// CPP-DECLTOP-NEXT: [[EXP_0]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: [[EXP_1]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: [[EXP_2]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: [[EXP_3]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: [[CAST]] = (bool) [[EXP_0]];
// CPP-DECLTOP-NEXT: [[ADD]] = [[EXP_1]] + [[VAL_4]];
// CPP-DECLTOP-NEXT: [[CALL]] = bar([[EXP_2]], [[VAL_4]]);
// CPP-DECLTOP-NEXT: [[COND]] = [[CAST]] ? [[EXP_3]] : [[VAL_4]];
// CPP-DECLTOP-NEXT: ;
// CPP-DECLTOP-NEXT: [[VAR]] = [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: return [[VAL_3]] / ([[VAL_1]] * [[VAL_2]]);
// CPP-DECLTOP-NEXT: }
func.func @user_with_expression_trait(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
%c0 = "emitc.constant"() {value = 0 : i32} : () -> i32
%e0 = emitc.expression : i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
%e1 = emitc.expression : i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
%e2 = emitc.expression : i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
%e3 = emitc.expression : i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
%e4 = emitc.expression : i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
%e5 = emitc.expression : i32 {
%0 = emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%1 = emitc.div %arg2, %0 : (i32, i32) -> i32
emitc.yield %1 : i32
}
%cast = emitc.cast %e0 : i32 to i1
%add = emitc.add %e1, %c0 : (i32, i32) -> i32
%call = emitc.call_opaque "bar" (%e2, %c0) : (i32, i32) -> (i32)
%cond = emitc.conditional %cast, %e3, %c0 : i32
%var = "emitc.variable"() {value = #emitc.opaque<"">} : () -> i32
emitc.assign %e4 : i32 to %var : i32
return %e5 : i32
}
// CPP-DEFAULT: int32_t multiple_uses(int32_t [[VAL_1:v[0-9]+]], int32_t [[VAL_2:v[0-9]+]], int32_t [[VAL_3:v[0-9]+]], int32_t [[VAL_4:v[0-9]+]]) {
// CPP-DEFAULT-NEXT: bool [[VAL_5:v[0-9]+]] = bar([[VAL_1]] * [[VAL_2]], [[VAL_3]]) - [[VAL_4]] < [[VAL_2]];
// CPP-DEFAULT-NEXT: int32_t [[VAL_6:v[0-9]+]];