Add alignment support to linalg.buffer_alloc

This CL adds an integer alignment attribute to linalg.buffer_alloc and its lowering to LLVM.
The alignment is constrained to be a positive power of 2.

Lowering to LLVM produces the pattern:
```
%[[alloc:.*]] = llvm.call @malloc(%[[s]]) : (!llvm.i64) -> !llvm<"i8*">
%[[cast:.*]] = llvm.ptrtoint %[[alloc]] : !llvm<"i8*"> to !llvm.i64
%[[rem:.*]] = llvm.urem %[[cast]], %[[c16]] : !llvm.i64
%[[drem:.*]] = llvm.sub %[[c16]], %[[rem]] : !llvm.i64
%[[off:.*]] = llvm.urem %[[drem]], %[[c16]] : !llvm.i64
llvm.getelementptr %{{.*}}[%[[off]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
```

where `ptr` is aligned on `align` by computing the address
`ptr + (align - ptr % align) % align`.

To allow the dealloc op to still be able to free the memory, additional information is needed
in the buffer type. The buffer type is thus extended with an extra i8* holding the base
allocation address.

PiperOrigin-RevId: 264244455
This commit is contained in:
Nicolas Vasilache
2019-08-19 14:36:49 -07:00
committed by A. Unique TensorFlower
parent 8165f181d9
commit 36f48063dd
7 changed files with 149 additions and 79 deletions

View File

@@ -41,7 +41,7 @@ class Linalg_Op<string mnemonic, list<OpTrait> traits = []> :
def BufferAllocOp :
Linalg_Op<"buffer_alloc">,
Arguments<(ins Variadic<Index>:$size)>,
Arguments<(ins Variadic<Index>:$size, OptionalAttr<I64Attr>:$alignment)>,
Results<(outs Buffer)> {
let summary = "buffer allocation operation";
let description = [{
@@ -49,11 +49,18 @@ def BufferAllocOp :
upon which a base view can be laid out to give it indexing semantics.
"buffer_alloc" takes a single argument, the size of the buffer to allocate
(in number of elements).
An optional alignment attribute may be specified in which case the actual
underlying allocation size may be increased. The base pointer is guaranteed
to be a multiple of `alignment`. Such an alignment must be a positive power
of 2.
Example:
Examples:
%0 = linalg.buffer_alloc(%arg0) : !linalg.buffer<?xf32>
%1 = linalg.buffer_alloc(%arg0) { alignment = 16 } :
!linalg.buffer<?xf32>
The size argument may be omitted if it is statically known, in which case it
must be reflected in the type.
@@ -61,12 +68,32 @@ def BufferAllocOp :
%0 = linalg.buffer_alloc() : !linalg.buffer<4xf32>
}];
let builders = [OpBuilder<
"Builder *builder, OperationState *result, BufferType bufferType", [{
result->types.push_back(bufferType);
}]
>];
let builders = [
OpBuilder<
"Builder *b, OperationState *result, BufferType bufferType", [{
result->addTypes(bufferType);
}]>,
OpBuilder<
"Builder *b, OperationState *result, BufferType bufferType, "
"unsigned alignment", [{
build(b, result, bufferType);
if (alignment != 0)
result->addAttribute(BufferAllocOp::getAlignmentAttrName(),
b->getI64IntegerAttr(alignment));
}]>,
OpBuilder<
"Builder *b, OperationState *result, BufferType bufferType, "
"Value *size, unsigned alignment", [{
if (alignment == 0)
return build(b, result, bufferType, size);
build(b, result, bufferType, size, b->getI64IntegerAttr(alignment));
}]>,
OpBuilder<
"Builder *b, OperationState *result, BufferType bufferType, Value *size",
[{ build(b, result, bufferType, size, 0); }]>
];
let extraClassDeclaration = [{
static StringRef getAlignmentAttrName() { return "alignment"; }
BufferType getBufferType() { return getType().cast<BufferType>(); }
Type getElementType() { return getBufferType().getElementType(); }
}];

View File

@@ -37,6 +37,7 @@
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -131,7 +132,11 @@ static void print(OpAsmPrinter *p, BufferAllocOp op) {
*p << op.getOperationName() << " ";
if (!llvm::empty(op.size()))
*p << *op.getOperand(0);
p->printOptionalAttrDict(op.getAttrs());
if (op.alignment().hasValue() && op.alignment()->getSExtValue() != 0)
p->printOptionalAttrDict(op.getAttrs());
else
p->printOptionalAttrDict(op.getAttrs(),
BufferAllocOp::getAlignmentAttrName());
*p << " : " << op.getBufferType();
}
@@ -160,6 +165,13 @@ static LogicalResult verify(BufferAllocOp op) {
if (op.getBufferType().getBufferSize().getValue() <= 0)
return op.emitOpError("expected nonnegative static buffer size");
}
if (op.alignment().hasValue()) {
auto align = op.alignment().getValue();
if (align.getSExtValue() < 0)
return op.emitOpError("expected positive alignment");
if (!llvm::isPowerOf2_64(align.getZExtValue()))
return op.emitOpError("expected power of 2 alignment");
}
if (!TensorType::isValidElementType(op.getElementType()))
return op.emitOpError("expected valid buffer element type");
return success();

View File

@@ -68,8 +68,10 @@ using llvm_load = ValueBuilder<LLVM::LoadOp>;
using llvm_store = OperationBuilder<LLVM::StoreOp>;
using llvm_select = ValueBuilder<LLVM::SelectOp>;
using mul = ValueBuilder<mlir::LLVM::MulOp>;
using ptrtoint = ValueBuilder<mlir::LLVM::PtrToIntOp>;
using sub = ValueBuilder<mlir::LLVM::SubOp>;
using undef = ValueBuilder<mlir::LLVM::UndefOp>;
using urem = ValueBuilder<mlir::LLVM::URemOp>;
using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
using llvm_return = OperationBuilder<LLVM::ReturnOp>;
@@ -99,12 +101,14 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
//
// template <typename Elem, size_t Rank>
// struct {
// void *baseAlloc;
// Elem *ptr;
// int64_t size;
// };
if (auto bufferType = t.dyn_cast<BufferType>()) {
auto voidPtrTy = LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto ptrTy = getPtrToElementType(bufferType, lowering);
return LLVMType::getStructTy(ptrTy, int64Ty);
return LLVMType::getStructTy(voidPtrTy, ptrTy, int64Ty);
}
// Range descriptor contains the range bounds and the step as 64-bit integers.
@@ -151,8 +155,9 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) {
return Type();
}
static constexpr int kPtrPosInBuffer = 0;
static constexpr int kSizePosInBuffer = 1;
static constexpr int kBasePtrPosInBuffer = 0;
static constexpr int kPtrPosInBuffer = 1;
static constexpr int kSizePosInBuffer = 2;
static constexpr int kPtrPosInView = 0;
static constexpr int kOffsetPosInView = 1;
static constexpr int kSizePosInView = 2;
@@ -215,13 +220,33 @@ public:
: operands[0];
Value *allocSize =
mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
Value *one = nullptr, *align = nullptr;
if (allocOp.alignment().hasValue()) {
one = constant(int64Ty,
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
align =
constant(int64Ty, rewriter.getIntegerAttr(
rewriter.getIndexType(),
allocOp.alignment().getValue().getSExtValue()));
allocSize = sub(add(allocSize, align), one);
}
Value *allocated =
llvm_call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
.getOperation()
->getResult(0);
allocated = bitcast(elementPtrType, allocated);
Value *data = allocated;
if (allocOp.alignment().hasValue()) {
// offset = (align - (ptr % align))% align
Value *offset =
urem(sub(align, urem(ptrtoint(int64Ty, allocated), align)), align);
data = gep(voidPtrTy, allocated, offset);
}
data = bitcast(elementPtrType, data);
Value *desc = undef(bufferDescriptorTy);
desc = insertvalue(bufferDescriptorTy, desc, allocated,
positionAttr(rewriter, kBasePtrPosInBuffer));
desc = insertvalue(bufferDescriptorTy, desc, data,
positionAttr(rewriter, kPtrPosInBuffer));
desc = insertvalue(bufferDescriptorTy, desc, size,
positionAttr(rewriter, kSizePosInBuffer));
@@ -252,18 +277,12 @@ public:
module.push_back(freeFunc);
}
// Get MLIR types for extracting element pointer.
auto deallocOp = cast<BufferDeallocOp>(op);
auto elementPtrTy =
getPtrToElementType(deallocOp.getBufferType(), lowering);
// Emit MLIR for buffer_dealloc.
BufferDeallocOpOperandAdaptor adaptor(operands);
edsc::ScopedContext context(rewriter, op->getLoc());
Value *casted =
bitcast(voidPtrTy, extractvalue(elementPtrTy, adaptor.buffer(),
positionAttr(rewriter, 0)));
llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
Value *base = extractvalue(voidPtrTy, adaptor.buffer(),
positionAttr(rewriter, kBasePtrPosInBuffer));
llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), base);
rewriter.replaceOp(op, llvm::None);
return matchSuccess();
}

View File

@@ -2,7 +2,6 @@
// -----
// CHECK-LABEL: buffer_alloc_single_index
func @buffer_alloc_single_index() {
// expected-error @+1 {{expected one index operand}}
%0 = linalg.buffer_alloc : !linalg.buffer<?xf32>
@@ -10,7 +9,6 @@ func @buffer_alloc_single_index() {
// -----
// CHECK-LABEL: buffer_alloc_unexpected_index
func @buffer_alloc_unexpected_index(%s : index) {
// expected-error @+1 {{expected zero operand}}
%0 = linalg.buffer_alloc %s : !linalg.buffer<32xf32>
@@ -18,7 +16,6 @@ func @buffer_alloc_unexpected_index(%s : index) {
// -----
// CHECK-LABEL: buffer_alloc_nonegative_size
func @buffer_alloc_nonegative_size() {
// expected-error @+1 {{expected nonnegative static buffer size}}
%0 = linalg.buffer_alloc : !linalg.buffer<0xf32>
@@ -26,7 +23,20 @@ func @buffer_alloc_nonegative_size() {
// -----
// CHECK-LABEL: buffer_valid_element_type
func @buffer_alloc_nonegative_alignment(%arg0: index) {
// expected-error @+1 {{expected positive alignment}}
%0 = linalg.buffer_alloc %arg0 {alignment = -123}: !linalg.buffer<?xf32>
}
// -----
func @buffer_alloc_powerof2_alignment(%arg0: index) {
// expected-error @+1 {{expected power of 2 alignment}}
%0 = linalg.buffer_alloc %arg0 {alignment = 123}: !linalg.buffer<?xf32>
}
// -----
func @buffer_valid_element_type() {
// expected-error @+1 {{expected valid buffer element type}}
%0 = linalg.buffer_alloc : !linalg.buffer<4xindex>
@@ -34,7 +44,6 @@ func @buffer_valid_element_type() {
// -----
// CHECK-LABEL: load_number_of_indices
func @load_number_of_indices(%v : !linalg.view<f32>) {
// expected-error @+2 {{expected 0 indices, got 1}}
%c0 = constant 0 : index
@@ -43,7 +52,6 @@ func @load_number_of_indices(%v : !linalg.view<f32>) {
// -----
// CHECK-LABEL: slice_number_of_indexings
func @slice_number_of_indexings(%arg0: !linalg.view<?x?xf32>) {
// expected-error @+2 {{expected 2 indexings, got 1}}
%c0 = constant 0: index
@@ -52,7 +60,6 @@ func @slice_number_of_indexings(%arg0: !linalg.view<?x?xf32>) {
// -----
// CHECK-LABEL: slice_rank_vs_range_indices
func @slice_rank_vs_range_indices(%arg0: !linalg.view<?x?xf32>) {
// expected-error @+2 {{op expected rank of the view(1) to be the number of ranges(0)}}
%c0 = constant 0: index
@@ -61,7 +68,6 @@ func @slice_rank_vs_range_indices(%arg0: !linalg.view<?x?xf32>) {
// -----
// CHECK-LABEL: store_number_of_indices
func @store_number_of_indices(%v : !linalg.view<f32>) {
// expected-error @+3 {{expected 0 indices, got 1}}
%c0 = constant 0 : index
@@ -71,7 +77,6 @@ func @store_number_of_indices(%v : !linalg.view<f32>) {
// -----
// CHECK-LABEL: subview_number_of_indices
func @subview_number_of_indices(%v : !linalg.view<?x?xf32>) {
// expected-error @+2 {{expected a view followed by 6 indices specifying a range for each dimension}}
%c0 = constant 0 : index
@@ -80,7 +85,6 @@ func @subview_number_of_indices(%v : !linalg.view<?x?xf32>) {
// -----
// CHECK-LABEL: view_type
func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
// expected-error @+2 {{expected view type}}
%r = linalg.range %min:%max:%step : !linalg.range
@@ -89,7 +93,6 @@ func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: in
// -----
// CHECK-LABEL: view_num_ranges
func @view_num_ranges(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
// expected-error @+2 {{expected 2 ranges}}
%r = linalg.range %min:%max:%step : !linalg.range
@@ -98,7 +101,6 @@ func @view_num_ranges(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %st
// -----
// CHECK-LABEL: yield_parent
func @yield_parent(%arg0: !linalg.view<?xf32>) {
// expected-error @+1 {{op expected 'linalg.generic' parent op}}
linalg.yield %arg0: !linalg.view<?xf32>
@@ -106,7 +108,6 @@ func @yield_parent(%arg0: !linalg.view<?xf32>) {
// -----
// CHECK-LABEL: generic_at_least_2_operands
func @generic_at_least_2_operands(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected 2 or more operands}}
linalg.generic {
@@ -119,7 +120,6 @@ func @generic_at_least_2_operands(%arg0: !linalg.view<f32>) {
// -----
// CHECK-LABEL: generic_exactly_2_views
func @generic_exactly_2_views(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected exactly 2 view operands}}
linalg.generic {
@@ -132,7 +132,6 @@ func @generic_exactly_2_views(%arg0: !linalg.view<f32>) {
// -----
// CHECK-LABEL: generic_undefined_fun
func @generic_undefined_fun(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected fun attribute to refer to a defined symbol}}
linalg.generic {
@@ -147,7 +146,6 @@ func @generic_undefined_fun(%arg0: !linalg.view<f32>) {
func @foo() { return }
// CHECK-LABEL: generic_mismatched_num_arguments
func @generic_mismatched_num_arguments(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected fun arguments to match number of views}}
linalg.generic {
@@ -162,7 +160,6 @@ func @generic_mismatched_num_arguments(%arg0: !linalg.view<f32>) {
func @foo(%0: i32) { return }
// CHECK-LABEL: generic_mismatched_num_returns
func @generic_mismatched_num_returns(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected fun results to match number of output views}}
linalg.generic {
@@ -177,7 +174,6 @@ func @generic_mismatched_num_returns(%arg0: !linalg.view<f32>) {
func @foo(%0: i32) -> i32 { return %0: i32 }
// CHECK-LABEL: generic_symbol_in_map
func @generic_symbol_in_map(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
linalg.generic {
@@ -192,7 +188,6 @@ func @generic_symbol_in_map(%arg0: !linalg.view<f32>) {
func @foo(%0: i32) -> i32 { return %0: i32 }
// CHECK-LABEL: generic_wrong_dim_in_map
func @generic_wrong_dim_in_map(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
linalg.generic {
@@ -207,7 +202,6 @@ func @generic_wrong_dim_in_map(%arg0: !linalg.view<f32>) {
func @foo(%0: i32) -> i32 { return %0: i32 }
// CHECK-LABEL: generic_zero_d_view
func @generic_zero_d_view(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: '!linalg.view<f32>'}}
linalg.generic {
@@ -222,7 +216,6 @@ func @generic_zero_d_view(%arg0: !linalg.view<f32>) {
func @foo(%0: f32) -> f32 { return %0: f32 }
// CHECK-LABEL: generic_one_d_view
func @generic_one_d_view(%arg0: !linalg.view<?xf32>) {
// expected-error @+1 {{op expected indexing_map #0 results to match view rank: '!linalg.view<?xf32>'}}
linalg.generic {
@@ -240,7 +233,6 @@ func @foo(%0: i32) -> f32 {
return %1: f32
}
// CHECK-LABEL: generic_fun_arg_0_element_type
func @generic_fun_arg_0_element_type(%arg0: !linalg.view<?xf32>) {
// expected-error @+1 {{op expected fun argument 0 to match view element type: 'f32'}}
linalg.generic {
@@ -258,7 +250,6 @@ func @foo(%0: f32) -> i4 {
return %1: i4
}
// CHECK-LABEL: generic_fun_result_0_element_type
func @generic_fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
// expected-error @+1 {{op expected fun result 0 to match output view element type: 'f32'}}
linalg.generic {
@@ -273,7 +264,6 @@ func @generic_fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 }
// CHECK-LABEL: generic_singular_maps
func @generic_singular_maps(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>) {
// expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
linalg.generic {
@@ -293,7 +283,6 @@ func @generic_singular_maps(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf3
// -----
// CHECK-LABEL: generic_empty_region
func @generic_empty_region(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected region with 1 block}}
linalg.generic {
@@ -308,7 +297,6 @@ func @generic_empty_region(%arg0: !linalg.view<f32>) {
// -----
// CHECK-LABEL: generic_mismatched_num_arguments
func @generic_mismatched_num_arguments(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected number of block arguments to match number of views}}
linalg.generic {
@@ -322,7 +310,6 @@ func @generic_mismatched_num_arguments(%arg0: !linalg.view<f32>) {
// -----
// CHECK-LABEL: generic_block_arg_type
func @generic_block_arg_type(%arg0: !linalg.view<f32>) {
// expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: '!linalg.view<f32>'}}
linalg.generic {
@@ -336,7 +323,6 @@ func @generic_block_arg_type(%arg0: !linalg.view<f32>) {
// -----
// CHECK-LABEL: generic_fun_result_0_element_type
func @generic_fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
// expected-error @+8 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
linalg.generic {

View File

@@ -7,9 +7,28 @@ func @buffer_size(%arg0: !linalg.buffer<?xf32>) {
return
}
// CHECK-LABEL: func @buffer_size
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64 }">
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i8*, float*, i64 }">
// CHECK-NEXT: llvm.add {{.*}}, {{.*}} : !llvm.i64
func @buffer_alloc_aligned(%arg0: index) {
%s = linalg.buffer_alloc %arg0 {alignment=16} : !linalg.buffer<?xf32>
return
}
// CHECK-LABEL: func @buffer_alloc_aligned
// CHECK: %[[c4:.*]] = llvm.constant(4 : index) : !llvm.i64
// CHECK: %[[m:.*]] = llvm.mul %arg0, %[[c4]] : !llvm.i64
// CHECK: %[[c1:.*]] = llvm.constant(1 : index) : !llvm.i64
// CHECK: %[[c16:.*]] = llvm.constant(16 : index) : !llvm.i64
// CHECK: %[[a:.*]] = llvm.add %[[m]], %[[c16]] : !llvm.i64
// CHECK: %[[s:.*]] = llvm.sub %[[a]], %[[c1]] : !llvm.i64
// CHECK: %[[alloc:.*]] = llvm.call @malloc(%[[s]]) : (!llvm.i64) -> !llvm<"i8*">
// aligning `ptr` on `align` is done by computing the address `ptr + (align - ptr % align) % align`.
// CHECK: %[[cast:.*]] = llvm.ptrtoint %[[alloc]] : !llvm<"i8*"> to !llvm.i64
// CHECK: %[[rem:.*]] = llvm.urem %[[cast]], %[[c16]] : !llvm.i64
// CHECK: %[[drem:.*]] = llvm.sub %[[c16]], %[[rem]] : !llvm.i64
// CHECK: %[[off:.*]] = llvm.urem %[[drem]], %[[c16]] : !llvm.i64
// CHECK: llvm.getelementptr %{{.*}}[%[[off]]] : (!llvm<"i8*">, !llvm.i64) -> !llvm<"i8*">
func @range(%arg0: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
@@ -32,7 +51,7 @@ func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
// CHECK-NEXT: llvm.constant(1 : index) : !llvm.i64
// CHECK-NEXT: llvm.alloca {{.*}} x !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> {alignment = 8 : i64} : (!llvm.i64) -> !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
// CHECK: llvm.load %{{.*}} : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ float*, i64 }">
// CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i8*, float*, i64 }">
// CHECK-NEXT: llvm.bitcast {{.*}} : !llvm<"float*"> to !llvm<"float*">
// CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
// CHECK-NEXT: llvm.constant(0 : index) : !llvm.i64

View File

@@ -8,7 +8,7 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
return
}
// CHECK-LABEL: func @range(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: %{{.*}} = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
// CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
func @buffer_size(%arg0: !linalg.buffer<?xf32>) -> index {
%0 = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
@@ -17,20 +17,27 @@ func @buffer_size(%arg0: !linalg.buffer<?xf32>) -> index {
// CHECK-LABEL: func @buffer_size
// CHECK: linalg.buffer_size {{.*}} : !linalg.buffer<?xf32>
func @buffer(%arg0: index, %arg1: index) {
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
%2 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %2 : !linalg.buffer<17xvector<4xi8>>
%2 = linalg.buffer_alloc %0 {alignment = 16} : !linalg.buffer<?xvector<4xi8>>
%3 = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
%4 = linalg.buffer_alloc {alignment = 32} : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %4 : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %3 : !linalg.buffer<17xvector<4xi8>>
linalg.buffer_dealloc %2 : !linalg.buffer<?xvector<4xi8>>
linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
return
}
// CHECK-LABEL: func @buffer(%{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: %{{.*}} = muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %{{.*}} = linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: %{{.*}} = linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} {alignment = 16 : i64} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_alloc : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_alloc {alignment = 32 : i64} : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<17xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xvector<4xi8>>
func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
@@ -52,15 +59,15 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
return
}
// CHECK-LABEL: func @views(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: %{{.*}} = muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %{{.*}} = linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: %{{.*}} = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
// CHECK-NEXT: %{{.*}} = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
// CHECK-NEXT: %{{.*}} = linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
// CHECK-NEXT: %{{.*}} = linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4x4xf32>>
// CHECK-NEXT: muli %{{.*}}, %{{.*}} : index
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
// CHECK-NEXT: linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
// CHECK-NEXT: linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
// CHECK-NEXT: linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4x4xf32>>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
func @ops(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<?xf32>, %arg3: !linalg.view<f32>) {
@@ -81,8 +88,8 @@ func @dim(%arg0: !linalg.view<?x?xf32>) {
return
}
// CHECK-LABEL: func @dim(%{{.*}}: !linalg.view<?x?xf32>) {
// CHECK-NEXT: %{{.*}} = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
// CHECK-NEXT: %{{.*}} = linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
// CHECK-NEXT: linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
// CHECK-NEXT: linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) {
@@ -99,13 +106,13 @@ func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) {
return
}
// CHECK-LABEL: func @linalg_for(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
// CHECK-NEXT: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = cmpi "slt", %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %{{.*}} = cmpi "sge", %{{.*}}, %{{.*}} : index
// CHECK-NEXT: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: cmpi "slt", %{{.*}}, %{{.*}} : index
// CHECK-NEXT: select %{{.*}}, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: cmpi "sge", %{{.*}}, %{{.*}} : index
// CHECK-NEXT: select %{{.*}}, %{{.*}}, %{{.*}} : index
// CHECK-NEXT: loop.for %{{.*}} to %{{.*}} step %{{.*}} {
func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
linalg.fill(%arg0, %arg1) : !linalg.view<?xf32>, f32
@@ -157,8 +164,8 @@ func @subview(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
return
}
// CHECK-LABEL: func @subview(%{{.*}}: !linalg.view<?x?xvector<3x4xi4>>) {
// CHECK: %{{.*}} = constant 0 : index
// CHECK: %{{.*}} = linalg.subview %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?xvector<3x4xi4>>
// CHECK: constant 0 : index
// CHECK: linalg.subview %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?xvector<3x4xi4>>
func @const_buffer_view(%arg0: index, %arg1: index, %arg2: index) {
%c0 = linalg.buffer_alloc : !linalg.buffer<17xf32>

View File

@@ -7,7 +7,7 @@
func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
%c0 = constant 0 : index
%c1 = constant 1 : index
%buf = linalg.buffer_alloc %s : !linalg.buffer<?xf32>
%buf = linalg.buffer_alloc %s {alignment = 256} : !linalg.buffer<?xf32>
%R = linalg.range %c0:%s:%c1 : !linalg.range
%V = linalg.view %buf[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
linalg.fill(%V, %f) : !linalg.view<?xf32>, f32