mirror of
https://github.com/intel/llvm.git
synced 2026-01-17 06:40:01 +08:00
[Linalg] Add *Conv1D* matchers (#168050)
-- This commit is the second in the series of adding matchers for linalg.*conv*/*pool*. Refer: https://github.com/llvm/llvm-project/pull/163724 -- In this commit all variants of Conv1D convolution ops have been added. -- For sake of completion for a specific infra required for those ops which don't require dilations/strides information during their creation, this commit also includes a basic Conv2D and Conv3D op as part of the lit test. Signed-off-by: Abhishek Varma <abhvarma@amd.com>
This commit is contained in:
@@ -245,14 +245,22 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp,
|
||||
ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
|
||||
SmallVector<Value> inputs = genericOp.getDpsInputs();
|
||||
ValueRange outputs = genericOp.getDpsInits();
|
||||
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
|
||||
SmallVector<Type> resultTypes = genericOp.hasPureTensorSemantics()
|
||||
? TypeRange(ValueRange(outputs))
|
||||
: TypeRange{};
|
||||
Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
|
||||
Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
|
||||
LinalgOp namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
|
||||
genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
|
||||
LinalgOp namedOp;
|
||||
// Ops with no dilations and no strides.
|
||||
if constexpr (std::is_same_v<ConvOpTy, linalg::Conv1DOp> ||
|
||||
std::is_same_v<ConvOpTy, linalg::Conv2DOp> ||
|
||||
std::is_same_v<ConvOpTy, linalg::Conv3DOp>) {
|
||||
namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(genericOp, resultTypes,
|
||||
inputs, outputs);
|
||||
} else {
|
||||
Attribute stridesAttr = rewriter.getI64TensorAttr(strides);
|
||||
Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations);
|
||||
namedOp = rewriter.replaceOpWithNewOp<ConvOpTy>(
|
||||
genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr);
|
||||
}
|
||||
return namedOp;
|
||||
}
|
||||
|
||||
@@ -265,9 +273,19 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
|
||||
return specializeToConvOp<ConvOpTy>(rewriter, genericOp, dilations, \
|
||||
strides); \
|
||||
// -----------------------------
|
||||
// Convolution ops.
|
||||
// -----------------------------
|
||||
CONV_OP_SPECIALIZER(linalg::Conv1DOp);
|
||||
CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp);
|
||||
CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp);
|
||||
CONV_OP_SPECIALIZER(linalg::Conv2DOp);
|
||||
CONV_OP_SPECIALIZER(linalg::Conv3DOp);
|
||||
// -----------------------------
|
||||
// Depthwise Convolution ops.
|
||||
// -----------------------------
|
||||
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp);
|
||||
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp);
|
||||
CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp);
|
||||
CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp);
|
||||
CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp);
|
||||
// -----------------------------
|
||||
|
||||
@@ -390,7 +390,7 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
|
||||
unsigned inputMapIdx = 0, filterMapIdx = 1,
|
||||
outputMapIdx = indexingMaps.size() - 1;
|
||||
AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
|
||||
auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
|
||||
auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
|
||||
if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
|
||||
return false;
|
||||
|
||||
@@ -434,6 +434,263 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
|
||||
})));
|
||||
}
|
||||
|
||||
// #inputMap = affine_map<(W, w) -> (W + w)>
|
||||
// #filterMap = affine_map<(W, w) -> (w)>
|
||||
// #outputMap = affine_map<(W, w) -> (W)>
|
||||
template <>
|
||||
bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
|
||||
SmallVector<int64_t> *dilations,
|
||||
SmallVector<int64_t> *strides) {
|
||||
if (isa<linalg::Conv1DOp>(op))
|
||||
return true;
|
||||
|
||||
assert(isaConvolutionOpInterface(op) &&
|
||||
"expected op to implement ConvolutionOpInterface");
|
||||
|
||||
*dilations = SmallVector<int64_t>(1, 1);
|
||||
*strides = SmallVector<int64_t>(1, 1);
|
||||
MLIRContext *context = op->getContext();
|
||||
AffineExpr W = getAffineDimExpr(0, context);
|
||||
AffineExpr w = getAffineDimExpr(1, context);
|
||||
ArrayAttr indexingMaps = op.getIndexingMaps();
|
||||
// First fetch dilations/strides :-
|
||||
// Match: W * stride + w * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
|
||||
/*oDim=*/0, (*dilations)[0], (*strides)[0]))
|
||||
return false;
|
||||
// Match expected indexing maps
|
||||
if (!convLayoutMatches(
|
||||
{/*inputMap=*/{W * (*strides)[0] + w * (*dilations)[0]},
|
||||
/*filterMap=*/{w},
|
||||
/*outputMap=*/{W}},
|
||||
indexingMaps, context))
|
||||
return false;
|
||||
// Match body
|
||||
Block *body = op.getBlock();
|
||||
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
|
||||
Value yieldVal = yieldOp.getOperand(0);
|
||||
return bodyMatcherForConvolutionOps(yieldVal, body);
|
||||
}
|
||||
|
||||
// #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)>
|
||||
// #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)>
|
||||
// #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)>
|
||||
template <>
|
||||
bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
|
||||
LinalgOp op, SmallVector<int64_t> *dilations,
|
||||
SmallVector<int64_t> *strides) {
|
||||
if (isa<linalg::Conv1DNwcWcfOp>(op))
|
||||
return true;
|
||||
|
||||
assert(isaConvolutionOpInterface(op) &&
|
||||
"expected op to implement ConvolutionOpInterface");
|
||||
|
||||
*dilations = SmallVector<int64_t>(1, 1);
|
||||
*strides = SmallVector<int64_t>(1, 1);
|
||||
MLIRContext *context = op->getContext();
|
||||
AffineExpr N = getAffineDimExpr(0, context);
|
||||
AffineExpr W = getAffineDimExpr(1, context);
|
||||
AffineExpr F = getAffineDimExpr(2, context);
|
||||
AffineExpr w = getAffineDimExpr(3, context);
|
||||
AffineExpr c = getAffineDimExpr(4, context);
|
||||
ArrayAttr indexingMaps = op.getIndexingMaps();
|
||||
// First fetch dilations/strides :-
|
||||
// Match: W * stride + w * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
|
||||
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
|
||||
return false;
|
||||
// Match expected indexing maps
|
||||
if (!convLayoutMatches(
|
||||
{/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], c},
|
||||
/*filterMap=*/{w, c, F},
|
||||
/*outputMap=*/{N, W, F}},
|
||||
indexingMaps, context))
|
||||
return false;
|
||||
// Match body
|
||||
Block *body = op.getBlock();
|
||||
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
|
||||
Value yieldVal = yieldOp.getOperand(0);
|
||||
return bodyMatcherForConvolutionOps(yieldVal, body);
|
||||
}
|
||||
|
||||
// #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)>
|
||||
// #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)>
|
||||
// #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)>
|
||||
template <>
|
||||
bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
|
||||
LinalgOp op, SmallVector<int64_t> *dilations,
|
||||
SmallVector<int64_t> *strides) {
|
||||
if (isa<linalg::Conv1DNcwFcwOp>(op))
|
||||
return true;
|
||||
|
||||
assert(isaConvolutionOpInterface(op) &&
|
||||
"expected op to implement ConvolutionOpInterface");
|
||||
|
||||
*dilations = SmallVector<int64_t>(1, 1);
|
||||
*strides = SmallVector<int64_t>(1, 1);
|
||||
MLIRContext *context = op->getContext();
|
||||
AffineExpr N = getAffineDimExpr(0, context);
|
||||
AffineExpr F = getAffineDimExpr(1, context);
|
||||
AffineExpr W = getAffineDimExpr(2, context);
|
||||
AffineExpr c = getAffineDimExpr(3, context);
|
||||
AffineExpr w = getAffineDimExpr(4, context);
|
||||
ArrayAttr indexingMaps = op.getIndexingMaps();
|
||||
// First fetch dilations/strides :-
|
||||
// Match: W * stride + w * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
|
||||
/*oDim=*/2, (*dilations)[0], (*strides)[0]))
|
||||
return false;
|
||||
// Match expected indexing maps
|
||||
if (!convLayoutMatches(
|
||||
{/*inputMap=*/{N, c, W * (*strides)[0] + w * (*dilations)[0]},
|
||||
/*filterMap=*/{F, c, w},
|
||||
/*outputMap=*/{N, F, W}},
|
||||
indexingMaps, context))
|
||||
return false;
|
||||
// Match body
|
||||
Block *body = op.getBlock();
|
||||
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
|
||||
Value yieldVal = yieldOp.getOperand(0);
|
||||
return bodyMatcherForConvolutionOps(yieldVal, body);
|
||||
}
|
||||
|
||||
// #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)>
|
||||
// #filterMap = affine_map<(H, W, h, w) -> (h, w)>
|
||||
// #outputMap = affine_map<(H, W, h, w) -> (H, W)>
|
||||
template <>
|
||||
bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
|
||||
SmallVector<int64_t> *dilations,
|
||||
SmallVector<int64_t> *strides) {
|
||||
if (isa<linalg::Conv2DOp>(op))
|
||||
return true;
|
||||
|
||||
assert(isaConvolutionOpInterface(op) &&
|
||||
"expected op to implement ConvolutionOpInterface");
|
||||
|
||||
*dilations = SmallVector<int64_t>(2, 1);
|
||||
*strides = SmallVector<int64_t>(2, 1);
|
||||
MLIRContext *context = op->getContext();
|
||||
AffineExpr H = getAffineDimExpr(0, context);
|
||||
AffineExpr W = getAffineDimExpr(1, context);
|
||||
AffineExpr h = getAffineDimExpr(2, context);
|
||||
AffineExpr w = getAffineDimExpr(3, context);
|
||||
ArrayAttr indexingMaps = op.getIndexingMaps();
|
||||
// First fetch dilations/strides :-
|
||||
// Match: H * stride + h * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
|
||||
/*oDim=*/0, (*dilations)[0], (*strides)[0]))
|
||||
return false;
|
||||
// Match: W * stride + w * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
|
||||
/*oDim=*/1, (*dilations)[1], (*strides)[1]))
|
||||
return false;
|
||||
// Match expected indexing maps
|
||||
if (!convLayoutMatches(
|
||||
{/*inputMap=*/{H * (*strides)[0] + h * (*dilations)[0],
|
||||
W * (*strides)[1] + w * (*dilations)[1]},
|
||||
/*filterMap=*/{h, w},
|
||||
/*outputMap=*/{H, W}},
|
||||
indexingMaps, context))
|
||||
return false;
|
||||
// Match body
|
||||
Block *body = op.getBlock();
|
||||
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
|
||||
Value yieldVal = yieldOp.getOperand(0);
|
||||
return bodyMatcherForConvolutionOps(yieldVal, body);
|
||||
}
|
||||
|
||||
// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)>
|
||||
// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)>
|
||||
// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)>
|
||||
template <>
|
||||
bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
|
||||
SmallVector<int64_t> *dilations,
|
||||
SmallVector<int64_t> *strides) {
|
||||
if (isa<linalg::Conv3DOp>(op))
|
||||
return true;
|
||||
|
||||
assert(isaConvolutionOpInterface(op) &&
|
||||
"expected op to implement ConvolutionOpInterface");
|
||||
|
||||
*dilations = SmallVector<int64_t>(3, 1);
|
||||
*strides = SmallVector<int64_t>(3, 1);
|
||||
MLIRContext *context = op->getContext();
|
||||
AffineExpr D = getAffineDimExpr(0, context);
|
||||
AffineExpr H = getAffineDimExpr(1, context);
|
||||
AffineExpr W = getAffineDimExpr(2, context);
|
||||
AffineExpr d = getAffineDimExpr(3, context);
|
||||
AffineExpr h = getAffineDimExpr(4, context);
|
||||
AffineExpr w = getAffineDimExpr(5, context);
|
||||
ArrayAttr indexingMaps = op.getIndexingMaps();
|
||||
// First fetch dilations/strides :-
|
||||
// Match: D * stride + d * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0,
|
||||
/*oDim=*/0, (*dilations)[0], (*strides)[0]))
|
||||
return false;
|
||||
// Match: H * stride + h * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1,
|
||||
/*oDim=*/1, (*dilations)[1], (*strides)[1]))
|
||||
return false;
|
||||
// Match: W * stride + w * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2,
|
||||
/*oDim=*/2, (*dilations)[2], (*strides)[2]))
|
||||
return false;
|
||||
// Match expected indexing maps
|
||||
if (!convLayoutMatches(
|
||||
{/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0],
|
||||
H * (*strides)[1] + h * (*dilations)[1],
|
||||
W * (*strides)[2] + w * (*dilations)[2]},
|
||||
/*filterMap=*/{d, h, w},
|
||||
/*outputMap=*/{D, H, W}},
|
||||
indexingMaps, context))
|
||||
return false;
|
||||
// Match body
|
||||
Block *body = op.getBlock();
|
||||
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
|
||||
Value yieldVal = yieldOp.getOperand(0);
|
||||
return bodyMatcherForConvolutionOps(yieldVal, body);
|
||||
}
|
||||
|
||||
// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)>
|
||||
// #filterMap = affine_map<(N, W, C, w) -> (C, w)>
|
||||
// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)>
|
||||
template <>
|
||||
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
|
||||
LinalgOp op, SmallVector<int64_t> *dilations,
|
||||
SmallVector<int64_t> *strides) {
|
||||
if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
|
||||
return true;
|
||||
|
||||
assert(isaConvolutionOpInterface(op) &&
|
||||
"expected op to implement ConvolutionOpInterface");
|
||||
|
||||
*dilations = SmallVector<int64_t>(1, 1);
|
||||
*strides = SmallVector<int64_t>(1, 1);
|
||||
MLIRContext *context = op->getContext();
|
||||
AffineExpr N = getAffineDimExpr(0, context);
|
||||
AffineExpr W = getAffineDimExpr(1, context);
|
||||
AffineExpr C = getAffineDimExpr(2, context);
|
||||
AffineExpr w = getAffineDimExpr(3, context);
|
||||
ArrayAttr indexingMaps = op.getIndexingMaps();
|
||||
// First fetch dilations/strides :-
|
||||
// Match: W * stride + w * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
|
||||
/*oDim=*/2, (*dilations)[0], (*strides)[0]))
|
||||
return false;
|
||||
// Match expected indexing maps
|
||||
if (!convLayoutMatches(
|
||||
{/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]},
|
||||
/*filterMap=*/{C, w},
|
||||
/*outputMap=*/{N, C, W}},
|
||||
indexingMaps, context))
|
||||
return false;
|
||||
// Match body
|
||||
Block *body = op.getBlock();
|
||||
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
|
||||
Value yieldVal = yieldOp.getOperand(0);
|
||||
return bodyMatcherForConvolutionOps(yieldVal, body);
|
||||
}
|
||||
|
||||
// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
|
||||
// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
|
||||
// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
|
||||
@@ -474,6 +731,47 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
|
||||
return bodyMatcherForConvolutionOps(yieldVal, body);
|
||||
}
|
||||
|
||||
// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)>
|
||||
// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)>
|
||||
// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)>
|
||||
template <>
|
||||
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
|
||||
LinalgOp op, SmallVector<int64_t> *dilations,
|
||||
SmallVector<int64_t> *strides) {
|
||||
if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
|
||||
return true;
|
||||
|
||||
assert(isaConvolutionOpInterface(op) &&
|
||||
"expected op to implement ConvolutionOpInterface");
|
||||
|
||||
*dilations = SmallVector<int64_t>(1, 1);
|
||||
*strides = SmallVector<int64_t>(1, 1);
|
||||
MLIRContext *context = op->getContext();
|
||||
AffineExpr N = getAffineDimExpr(0, context);
|
||||
AffineExpr W = getAffineDimExpr(1, context);
|
||||
AffineExpr C = getAffineDimExpr(2, context);
|
||||
AffineExpr CM = getAffineDimExpr(3, context);
|
||||
AffineExpr w = getAffineDimExpr(4, context);
|
||||
ArrayAttr indexingMaps = op.getIndexingMaps();
|
||||
// First fetch dilations/strides :-
|
||||
// Match: W * stride + w * dilation
|
||||
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
|
||||
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
|
||||
return false;
|
||||
// Match expected indexing maps
|
||||
if (!convLayoutMatches(
|
||||
{/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
|
||||
/*filterMap=*/{w, C, CM},
|
||||
/*outputMap=*/{N, W, C, CM}},
|
||||
indexingMaps, context))
|
||||
return false;
|
||||
// Match body
|
||||
Block *body = op.getBlock();
|
||||
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
|
||||
Value yieldVal = yieldOp.getOperand(0);
|
||||
return bodyMatcherForConvolutionOps(yieldVal, body);
|
||||
}
|
||||
|
||||
// #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)>
|
||||
// #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)>
|
||||
// #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)>
|
||||
|
||||
@@ -1,8 +1,87 @@
|
||||
// The following test examples of linalg convolution named ops lowered to linalg.generic and then
|
||||
// lifted back up to named op.
|
||||
// NOTE: Most tests in this file use dynamic shapes as the underlying transformations don't modify shapes. There's one exception that's added as a smoke test.
|
||||
|
||||
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic
|
||||
|
||||
// NOTE: Most tests in this file use dynamic shapes as the underlying transformations don't modify shapes. There's one exception that's added as a smoke test.
|
||||
// -----------------------------
|
||||
// Convolution ops.
|
||||
// -----------------------------
|
||||
func.func @conv_1d(%in : tensor<?xf32>, %filter : tensor<?xf32>, %out : tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = linalg.conv_1d
|
||||
ins(%in, %filter : tensor<?xf32>, tensor<?xf32>)
|
||||
outs(%out : tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
// CHECK: @conv_1d
|
||||
// CHECK: linalg.conv_1d
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv_1d_nwc_wcf(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
|
||||
%0 = linalg.conv_1d_nwc_wcf
|
||||
{dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
|
||||
ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
|
||||
outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
// CHECK: @conv_1d_nwc_wcf
|
||||
// CHECK: linalg.conv_1d_nwc_wcf
|
||||
// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv_1d_ncw_fcw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
|
||||
%0 = linalg.conv_1d_ncw_fcw
|
||||
{dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
|
||||
ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
|
||||
outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
// CHECK: @conv_1d_ncw_fcw
|
||||
// CHECK: linalg.conv_1d_ncw_fcw
|
||||
// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv_2d(%in : tensor<?x?xf32>, %filter : tensor<?x?xf32>, %out : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%0 = linalg.conv_2d
|
||||
ins(%in, %filter : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||
outs(%out: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %0 : tensor<?x?xf32>
|
||||
}
|
||||
// CHECK: @conv_2d
|
||||
// CHECK: linalg.conv_2d
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv_3d(%in : tensor<?x?x?xf32>, %filter : tensor<?x?x?xf32>, %out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
|
||||
%0 = linalg.conv_3d
|
||||
ins(%in, %filter : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
|
||||
outs(%out : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
// CHECK: @conv_3d
|
||||
// CHECK: linalg.conv_3d
|
||||
|
||||
// -----
|
||||
|
||||
// -----------------------------
|
||||
// Depthwise Convolution ops.
|
||||
// -----------------------------
|
||||
func.func @depthwise_conv_1d_ncw_cw(%input: tensor<?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
|
||||
%0 = linalg.depthwise_conv_1d_ncw_cw
|
||||
{dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
|
||||
ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?xf32>)
|
||||
outs (%output: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
|
||||
return %0 : tensor<?x?x?xf32>
|
||||
}
|
||||
// CHECK: @depthwise_conv_1d_ncw_cw
|
||||
// CHECK: linalg.depthwise_conv_1d_ncw_cw
|
||||
// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @depthwise_conv_1d_nwc_wc_static(%input: tensor<1x25x8xi8>, %filter: tensor<3x8xi8>, %output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> {
|
||||
%0 = linalg.depthwise_conv_1d_nwc_wc
|
||||
{dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}
|
||||
@@ -16,6 +95,19 @@ func.func @depthwise_conv_1d_nwc_wc_static(%input: tensor<1x25x8xi8>, %filter: t
|
||||
|
||||
// -----
|
||||
|
||||
func.func @depthwise_conv_1d_nwc_wcm(%input: tensor<?x?x?xf32>, %filter: tensor<?x?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
|
||||
%0 = linalg.depthwise_conv_1d_nwc_wcm
|
||||
{dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
|
||||
ins (%input, %filter: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
|
||||
outs (%output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
return %0 : tensor<?x?x?x?xf32>
|
||||
}
|
||||
// CHECK: @depthwise_conv_1d_nwc_wcm
|
||||
// CHECK: linalg.depthwise_conv_1d_nwc_wcm
|
||||
// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>
|
||||
|
||||
// -----
|
||||
|
||||
func.func @depthwise_conv_2d_nchw_chw(%input: tensor<?x?x?x?xf16>, %filter: tensor<?x?x?xf16>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
|
||||
%0 = linalg.depthwise_conv_2d_nchw_chw
|
||||
{dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>}
|
||||
@@ -42,6 +134,9 @@ func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor<?x?x?x?x?xf32>, %filter:
|
||||
|
||||
// -----
|
||||
|
||||
// -----------------------------
|
||||
// Pooling ops.
|
||||
// -----------------------------
|
||||
func.func @pooling_nhwc_max(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?xf32>, %output: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
|
||||
%0 = linalg.pooling_nhwc_max
|
||||
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
|
||||
|
||||
Reference in New Issue
Block a user