isl: detect vector parallelism

llvm-svn: 170138
This commit is contained in:
Sebastian Pop
2012-12-13 16:52:41 +00:00
parent 62167ca254
commit e252c85545
4 changed files with 67 additions and 14 deletions

View File

@@ -81,6 +81,9 @@ static void IslAstUserFree(void *User)
// Parallelism flags for an AST node. astBuildAfterFor() reads this structure
// back from the user pointer of the node's annotation id. Both fields are used
// as booleans (0 = false) and allocateAstNodeUserInfo() initializes them to 0.
struct AstNodeUserInfo {
// The node is the outermost parallel loop. When set, printParallelFor()
// prints "#pragma omp parallel for" in front of the loop.
int IsOutermostParallel;
// The node is the innermost parallel loop. When set, printParallelFor()
// prints "#pragma simd" in front of the loop.
int IsInnermostParallel;
};
// Temporary information used when building the ast.
@@ -92,16 +95,22 @@ struct AstBuildUserInfo {
int InParallelFor;
};
// Print a loop annotated with OpenMP or vector pragmas.
//
// Each pragma is emitted on its own line directly in front of the loop:
//   - "#pragma simd" when Info marks the loop as innermost parallel,
//   - "#pragma omp parallel for" when Info marks it as outermost parallel.
// A loop that is both gets both lines, the simd pragma first. When Info is
// null no pragma is printed. The loop itself is then printed by
// isl_ast_node_for_print(), which receives Printer and PrintOptions.
static __isl_give isl_printer *
printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer,
                 __isl_take isl_ast_print_options *PrintOptions,
                 AstNodeUserInfo *Info) {
  if (Info) {
    if (Info->IsInnermostParallel) {
      Printer = isl_printer_start_line(Printer);
      Printer = isl_printer_print_str(Printer, "#pragma simd");
      Printer = isl_printer_end_line(Printer);
    }
    if (Info->IsOutermostParallel) {
      Printer = isl_printer_start_line(Printer);
      Printer = isl_printer_print_str(Printer, "#pragma omp parallel for");
      Printer = isl_printer_end_line(Printer);
    }
  }
  return isl_ast_node_for_print(Node, Printer, PrintOptions);
}
@@ -126,6 +135,7 @@ static struct AstNodeUserInfo *allocateAstNodeUserInfo() {
struct AstNodeUserInfo *NodeInfo;
NodeInfo = (struct AstNodeUserInfo *) malloc(sizeof(struct AstNodeUserInfo));
NodeInfo->IsOutermostParallel = 0;
NodeInfo->IsInnermostParallel = 0;
return NodeInfo;
}
@@ -223,6 +233,46 @@ static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build,
return Id;
}
// Returns 0 when Node contains at least one loop, otherwise returns -1.
//
// Node is consumed (__isl_take) on every path. The signature matches the
// callback type expected by isl_ast_node_list_foreach, which is why the unused
// User parameter is kept. Please use the single argument wrapper function that
// returns a bool instead of using this function directly.
static int containsLoops(__isl_take isl_ast_node *Node, void *User) {
  if (!Node)
    return -1;

  int Res = -1;
  switch (isl_ast_node_get_type(Node)) {
  case isl_ast_node_for:
    Res = 0;
    break;
  case isl_ast_node_block: {
    // Do not hand this function to isl_ast_node_list_foreach: foreach aborts
    // on the first negative return value, so a child without loops (-1) would
    // end the search and hide loops in later children, while an empty list
    // would be reported as containing loops. Walk the children by index and
    // stop at the first one that contains a loop.
    isl_ast_node_list *List = isl_ast_node_block_get_children(Node);
    int NumChildren = isl_ast_node_list_n_ast_node(List);
    for (int i = 0; i < NumChildren && Res != 0; ++i)
      Res = containsLoops(isl_ast_node_list_get_ast_node(List, i), NULL);
    isl_ast_node_list_free(List);
    break;
  }
  case isl_ast_node_if:
    // A conditional contains a loop if either branch does. The else branch is
    // only queried when it exists.
    if (0 == containsLoops(isl_ast_node_if_get_then(Node), NULL) ||
        (isl_ast_node_if_has_else(Node) &&
         0 == containsLoops(isl_ast_node_if_get_else(Node), NULL)))
      Res = 0;
    break;
  case isl_ast_node_user:
  default:
    break;
  }

  isl_ast_node_free(Node);
  return Res;
}
// Boolean convenience wrapper around the two-argument containsLoops():
// yields true when Node contains a loop. Node is handed over (__isl_take) to
// the two-argument overload.
static bool containsLoops(__isl_take isl_ast_node *Node) {
  const int Status = containsLoops(Node, NULL);
  return Status == 0;
}
// This method is executed after the construction of a for node.
//
// It performs the following actions:
@@ -233,17 +283,17 @@ static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build,
static __isl_give isl_ast_node *
astBuildAfterFor(__isl_take isl_ast_node *Node,
__isl_keep isl_ast_build *Build, void *User) {
isl_id *Id;
struct AstBuildUserInfo *BuildInfo;
struct AstNodeUserInfo *Info;
Id = isl_ast_node_get_annotation(Node);
isl_id *Id = isl_ast_node_get_annotation(Node);
if (!Id)
return Node;
Info = (struct AstNodeUserInfo *) isl_id_get_user(Id);
if (Info && Info->IsOutermostParallel) {
BuildInfo = (struct AstBuildUserInfo *) User;
BuildInfo->InParallelFor = 0;
struct AstNodeUserInfo *Info = (struct AstNodeUserInfo *) isl_id_get_user(Id);
struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *) User;
if (Info) {
if (Info->IsOutermostParallel)
BuildInfo->InParallelFor = 0;
if (!containsLoops(isl_ast_node_for_get_body(Node)))
if (astScheduleDimIsParallel(Build, BuildInfo->Deps))
Info->IsInnermostParallel = 1;
}
isl_id_free(Id);

View File

@@ -50,6 +50,7 @@ ret:
; memory accesses, that would happen if n >= 1024.
;
; CHECK: for (int c1 = 0; c1 < n; c1 += 1)
; CHECK: #pragma simd
; CHECK: #pragma omp parallel for
; CHECK: for (int c3 = 0; c3 < n; c3 += 1)
; CHECK: Stmt_loop_body(c1, c3);

View File

@@ -41,6 +41,7 @@ ret:
}
; CHECK: for (int c1 = 0; c1 < n; c1 += 1)
; CHECK: #pragma simd
; CHECK: #pragma omp parallel for
; CHECK: for (int c3 = 0; c3 < n; c3 += 1)
; CHECK: Stmt_loop_body(c1, c3);

View File

@@ -30,6 +30,7 @@ ret:
ret void
}
; CHECK: #pragma simd
; CHECK: #pragma omp parallel for
; CHECK: for (int c1 = 0; c1 < n; c1 += 1)
; CHECK: Stmt_loop_body(c1)