[mlir] Simplify DestinationStyleOpInterface.
Differential Revision: https://reviews.llvm.org/D135348
@@ -317,7 +317,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
        /*args=*/(ins "OpOperand *":$opOperand),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          if (!$_op.isOutputTensor(opOperand))
          if (!$_op.isOutput(opOperand))
            return false;
          return payloadUsesValueFromOperand(opOperand);
        }]
@@ -606,7 +606,13 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
        /*args=*/(ins),
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          return $_op.getInputAndOutputOperands();
          OpOperandVector result;
          result.reserve($_op->getNumOperands());
          llvm::transform(
            this->getOperation()->getOpOperands(),
            std::back_inserter(result),
            [](OpOperand &opOperand) { return &opOperand; });
          return result;
        }]
      >,
      //===------------------------------------------------------------------===//
@@ -684,13 +690,8 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
        /*methodBody=*/"",
        /*defaultImplementation=*/[{
          SmallVector<int64_t> res;
          // MLIR currently does not support dependent interfaces or interface
          // inheritance. By construction all ops with StructuredOpInterface must
          // implement DestinationStyleOpInterface.
          // TODO: reevaluate the need for a cast when a better mechanism exists.
          auto iface = cast<DestinationStyleOpInterface>(*this->getOperation());
          for (OpOperand *opOperand : iface.getInputAndOutputOperands())
            llvm::append_range(res, getShape(opOperand));
          for (OpOperand &opOperand : this->getOperation()->getOpOperands())
            llvm::append_range(res, getShape(&opOperand));
          return res;
        }]
      >,
@@ -779,31 +780,16 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
    // TODO: reevaluate the need for a cast when a better mechanism exists.
    //========================================================================//

    ValueRange getInputs() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getInputs();
    }

    int64_t getNumInputs() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getNumInputs();
    }

    ValueRange getOutputs() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getOutputs();
    }

    int64_t getNumOutputs() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getNumOutputs();
    }

    int64_t getNumInputsAndOutputs() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getNumInputsAndOutputs();
    }

    OpOperandVector getInputOperands() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getInputOperands();
@@ -814,14 +800,9 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
          .getInputOperand(i);
    }

    OpOperandVector getInputBufferOperands() {
    void setOutputOperand(int64_t i, Value value) {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getInputBufferOperands();
    }

    OpOperandVector getInputTensorOperands() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getInputTensorOperands();
          .setOutputOperand(i, value);
    }

    OpOperandVector getOutputOperands() {
@@ -834,44 +815,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
          .getOutputOperand(i);
    }

    void setOutputOperand(int64_t i, Value value) {
    bool isInput(OpOperand *opOperand) {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .setOutputOperand(i, value);
          .isInput(opOperand);
    }

    OpOperandVector getOutputBufferOperands() {
    bool isOutput(OpOperand *opOperand) {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getOutputBufferOperands();
    }

    OpOperandVector getOutputTensorOperands() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getOutputTensorOperands();
    }

    SmallVector<MemRefType> getOutputBufferTypes() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getOutputBufferTypes();
    }

    SmallVector<RankedTensorType> getOutputTensorTypes() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getOutputTensorTypes();
    }

    OpOperandVector getInputAndOutputOperands() {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .getInputAndOutputOperands();
    }

    bool isInputTensor(OpOperand *opOperand) {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .isInputTensor(opOperand);
    }

    bool isOutputTensor(OpOperand *opOperand) {
      return cast<DestinationStyleOpInterface>(*this->getOperation())
          .isOutputTensor(opOperand);
          .isOutput(opOperand);
    }

    bool isScalar(OpOperand *opOperand) {
@@ -928,331 +879,185 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
  let verifyWithRegions = 1;
}

// The 'DestinationStyleOpInterface' provides access to the methods relevant
// for destination-style ops. A destination-style operation has 'n' input
// arguments and 'm' output arguments. Each op that wants to implement
// DestinationStyleOpInterface needs to define getInputs() and getOutputs()
// methods.
// Ops that are in destination style have designated output operands, which act
// as initial tensor values for the results of the operation or the output
// buffers to which the results of the op will be written.
//
// Output operands must be tensors or memrefs. Input operands can have any
// type. All non-output operands are inputs.

// It is assumed that the output operands of the op are the operands at
// position [start, end). The positions are defined by the
// getOutputsPositionRange() method. All non-output operands are "inputs" of
// the DPS op.

// If the op has "tensor semantics", then the input operands are either scalars
// or tensors. The output operands are tensors and every tensor output is tied
// to a corresponding tensor OpResult in a 1-to-1 fashion. The i-th output
// tensor is tied to the i-th OpResult. The op may not have any additional
// OpResults. Output operands and their tied OpResults have the same type.
//
// If the op has "buffer semantics", then the input operands are either memrefs
// or other non-tensor types, e.g. scalar types. Furthermore, the output
// operands are memrefs and the op has no results.
//
// Destination-passing style abstraction makes certain transformations easier.
// For example, tiling implementation can extract/insert slices from/into the
// destination of an op and use the resulting shaped value as an iter_arg in
// the surrounding loop structure. As another example, bufferization does not
// have to allocate new buffers for destinations (in case of in-place
// bufferization) and can directly reuse the existing destination buffer.
//
// Example of a destination style op: `%r = tensor.insert_slice %t into %d`,
// where `%t` is the single input and `%d` is the single output. `%d` is tied
// to `%r`.
//
// Example of an op that is not in destination style: `%r = tensor.pad %t`.
// This op is not in destination style because `%r` and `%t` have different
// shapes.
//
// Each op that wants to implement DestinationStyleOpInterface needs to define
// the getOutputsPositionRange() method.
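// [Editor's note] A minimal sketch of the contract described above, for a
// hypothetical op `MyDpsOp` (not part of this patch) whose operands are laid
// out as <inputs..., inits...>; the single method below is all the op must
// supply, and the interface defaults derive everything else from the range:
//
//   std::pair<int64_t, int64_t> getOutputsPositionRange() {
//     int64_t numOperands = getNumOperands();
//     return {numOperands - getInits().size(), numOperands};
//   }
//
// With outputs at [start, end) this gives, for example,
//   getNumOutputs()     == end - start
//   getOutputOperand(i) == &getOpOperand(start + i)
//   isInput(opOperand)  == operandNumber < start || operandNumber >= end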
def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
  let cppNamespace = "::mlir::linalg";
  let methods = [
    //===------------------------------------------------------------------===//
    // Num input/output arguments handling.
    //===------------------------------------------------------------------===//
    // `getInputs` must be defined by each op that wants to implement the
    // DestinationStyleOpInterface.
    // This method has to be defined for every DPS op.
    InterfaceMethod<
      /*desc=*/[{
        Return the input shape operands.
      }],
      /*retTy=*/"ValueRange",
      /*methodName=*/"getInputs",
      /*args=*/(ins)
    >,
    // These special methods rely on `getInputs` and `getOutputs` being defined
    // by each op that wants to implement the DestinationStyleOpInterface.
    InterfaceMethod<
      /*desc=*/[{
        Return the number of inputs.
      }],
      /*retTy=*/"int64_t",
      /*methodName=*/"getNumInputs",
      /*desc=*/"Return start and end indices of the output operands range.",
      /*retTy=*/"std::pair<int64_t, int64_t>",
      /*methodName=*/"getOutputsPositionRange",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getInputs().size();
      }]
      /*defaultImplementation=*/""
    >,
    // `getOutputs` must be defined by each op that wants to implement the
    // DestinationStyleOpInterface.
    //===------------------------------------------------------------------===//
    // Operands handling.
    //===------------------------------------------------------------------===//
    // The operand list is assumed to start with the input operands and end
    // with the output operands. Therefore, all methods to access the inputs
    // and outputs can be expressed if the number of output operands is known.
    InterfaceMethod<
      /*desc=*/[{
        Return the output shape operands.
      }],
      /*retTy=*/"ValueRange",
      /*methodName=*/"getOutputs",
      /*args=*/(ins)
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of outputs.
      }],
      /*desc=*/"Return the number of outputs.",
      /*retTy=*/"int64_t",
      /*methodName=*/"getNumOutputs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getOutputs().size();
        auto [start, end] = $_op.getOutputsPositionRange();
        return end - start;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the number of inputs and outputs.
      }],
      /*retTy=*/"int64_t",
      /*methodName=*/"getNumInputsAndOutputs",
      /*desc=*/"Return the output operands.",
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getOutputOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return this->getOperation()->getNumOperands();
        auto [start, end] = $_op.getOutputsPositionRange();

        OpOperandVector result;
        result.reserve(end - start);
        for (int i = start; i < end; ++i)
          result.push_back(&$_op->getOpOperand(i));
        return result;
      }]
    >,
    //===------------------------------------------------------------------===//
    // Input operands handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the input operands.
      }],
      /*desc=*/"Return the `i`-th output operand.",
      /*retTy=*/"OpOperand*",
      /*methodName=*/"getOutputOperand",
      /*args=*/(ins "int64_t":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i >= 0 && i < $_op.getNumOutputs());
        auto [start, end] = $_op.getOutputsPositionRange();
        return &$_op->getOpOperand(start + i);
      }]
    >,
    InterfaceMethod<
      /*desc=*/"Set the `i`-th output operand.",
      /*retTy=*/"void",
      /*methodName=*/"setOutputOperand",
      /*args=*/(ins "int64_t":$i, "Value":$value),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i >= 0 && i < $_op.getNumOutputs());
        auto [start, end] = $_op.getOutputsPositionRange();
        $_op->setOperand(start + i, value);
      }]
    >,
    InterfaceMethod<
      /*desc=*/"Return the number of inputs.",
      /*retTy=*/"int64_t",
      /*methodName=*/"getNumInputs",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getNumOperands() - $_op.getNumOutputs();
      }]
    >,
    InterfaceMethod<
      /*desc=*/"Return the input operands.",
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getInputOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        int64_t numInputs = getNumInputs();
        auto [start, end] = $_op.getOutputsPositionRange();
        int64_t numOutputs = end - start;
        int64_t numOperands = $_op.getNumOperands();

        OpOperandVector result;
        result.reserve(numInputs);
        llvm::transform(
          this->getOperation()->getOpOperands().take_front(numInputs),
          std::back_inserter(result),
          [](OpOperand &opOperand) { return &opOperand; });
        result.reserve(numOperands - numOutputs);
        for (int i = 0; i < start; ++i)
          result.push_back(&$_op->getOpOperand(i));
        for (int i = end; i < numOperands; ++i)
          result.push_back(&$_op->getOpOperand(i));

        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `i`-th input operand.
      }],
      /*desc=*/[{ Return the `i`-th input operand. }],
      /*retTy=*/"OpOperand*",
      /*methodName=*/"getInputOperand",
      /*args=*/(ins "int64_t":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i >= 0 && i < getNumInputs());
        return &this->getOperation()->getOpOperand(i);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of input operands that are of buffer type.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getInputBufferOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        OpOperandVector result;
        result.reserve(getNumInputs());
        llvm::copy_if(getInputOperands(),
          std::back_inserter(result),
          [](OpOperand *opOperand) {
            return opOperand->get().getType().template isa<MemRefType>();
          });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of input operands that are of tensor type.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getInputTensorOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        OpOperandVector result;
        result.reserve(getNumInputs());
        llvm::copy_if(getInputOperands(),
          std::back_inserter(result),
          [](OpOperand *opOperand) {
            return opOperand->get().getType().template isa<RankedTensorType>();
          });
        return result;
      }]
    >,
    //===------------------------------------------------------------------===//
    // Output operands handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the output operands.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getOutputOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        int64_t numOutputs = getNumOutputs();
        OpOperandVector result;
        result.reserve(numOutputs);
        llvm::transform(
          this->getOperation()->getOpOperands()
            .take_back(numOutputs),
          std::back_inserter(result),
          [](OpOperand &opOperand) { return &opOperand; });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the `i`-th output operand.
      }],
      /*retTy=*/"OpOperand*",
      /*methodName=*/"getOutputOperand",
      /*args=*/(ins "int64_t":$i),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i >= 0 && i < getNumOutputs());
        return &this->getOperation()->getOpOperand(getNumInputs() + i);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Set the `i`-th output operand.
      }],
      /*retTy=*/"void",
      /*methodName=*/"setOutputOperand",
      /*args=*/(ins "int64_t":$i, "Value":$value),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(i >= 0 && i < getNumOutputs());
        this->getOperation()->setOperand(getNumInputs() + i, value);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of output operands that are of buffer type.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getOutputBufferOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        OpOperandVector result;
        result.reserve(getNumOutputs());
        llvm::copy_if(getOutputOperands(),
          std::back_inserter(result),
          [](OpOperand *opOperand) {
            return opOperand->get().getType().template isa<MemRefType>();
          });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the subset of output operands that are of tensor type.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getOutputTensorOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        OpOperandVector result;
        result.reserve(getNumOutputs());
        llvm::copy_if(getOutputOperands(),
          std::back_inserter(result),
          [](OpOperand *opOperand) {
            return opOperand->get().getType().template isa<RankedTensorType>();
          });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the types of the subset of output operands that are of buffer type.
      }],
      /*retTy=*/"SmallVector<MemRefType>",
      /*methodName=*/"getOutputBufferTypes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<MemRefType> result;
        result.reserve(getNumOutputs());
        llvm::transform(getOutputBufferOperands(),
          std::back_inserter(result),
          [](OpOperand *opOperands) {
            return opOperands->get().getType().cast<MemRefType>();
          });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the types of the subset of output operands that are of tensor type.
      }],
      /*retTy=*/"SmallVector<RankedTensorType>",
      /*methodName=*/"getOutputTensorTypes",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        SmallVector<RankedTensorType> result;
        result.reserve(getNumOutputs());
        llvm::transform(getOutputTensorOperands(),
          std::back_inserter(result),
          [](OpOperand *opOperands) {
            return opOperands->get().getType().cast<RankedTensorType>();
          });
        return result;
        auto [start, end] = $_op.getOutputsPositionRange();
        return &$_op->getOpOperand(i < start ? i : i + end - start);
      }]
    >,
    //===------------------------------------------------------------------===//
    // Input and Output arguments handling.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return the range over input and output operands.
      }],
      /*retTy=*/"OpOperandVector",
      /*methodName=*/"getInputAndOutputOperands",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        int64_t numInputsAndOutputs = getNumInputsAndOutputs();
        OpOperandVector result;
        result.reserve(numInputsAndOutputs);
        llvm::transform(
          this->getOperation()->getOpOperands(),
          std::back_inserter(result),
          [](OpOperand &opOperand) { return &opOperand; });
        return result;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if `opOperand` is an input tensor.
      }],
      /*desc=*/"Return true if `opOperand` is an input.",
      /*retTy=*/"bool",
      /*methodName=*/"isInputTensor",
      /*methodName=*/"isInput",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (!opOperand->get().getType().template isa<RankedTensorType>())
          return false;
        if (opOperand->getOperandNumber() < $_op.getNumInputs())
          return true;
        return false;
        auto [start, end] = $_op.getOutputsPositionRange();
        auto operandNumber = opOperand->getOperandNumber();
        return operandNumber < start || operandNumber >= end;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if `opOperand` is an output tensor.
      }],
      /*desc=*/"Return true if `opOperand` is an output.",
      /*retTy=*/"bool",
      /*methodName=*/"isOutputTensor",
      /*methodName=*/"isOutput",
      /*args=*/(ins "OpOperand *":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        if (!opOperand->get().getType().template isa<RankedTensorType>())
          return false;
        if (opOperand->getOperandNumber() >= $_op.getNumInputs())
          return true;
        return false;
        auto [start, end] = $_op.getOutputsPositionRange();
        auto operandNumber = opOperand->getOperandNumber();
        return operandNumber >= start && operandNumber < end;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if the `opOperand` is a scalar value.
      }],
      /*desc=*/"Return true if the `opOperand` is a scalar value.",
      /*retTy=*/"bool",
      /*methodName=*/"isScalar",
      /*args=*/(ins "OpOperand*":$opOperand),
@@ -1263,35 +1068,33 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the result tied to `opOperand`.
      }],
      /*desc=*/"Return the result tied to `opOperand`.",
      /*retTy=*/"OpResult",
      /*methodName=*/"getTiedOpResult",
      /*args=*/(ins "OpOperand*":$opOperand),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        assert(opOperand->getOwner() == this->getOperation());
        int64_t resultIndex = opOperand->getOperandNumber() - getNumInputs();

        auto [start, end] = $_op.getOutputsPositionRange();
        int64_t resultIndex = opOperand->getOperandNumber() - start;
        assert(resultIndex >= 0 &&
               resultIndex < this->getOperation()->getNumResults() );
        return this->getOperation()->getResult(resultIndex);
               resultIndex < $_op->getNumResults() );
        return $_op->getResult(resultIndex);
      }]
    >,
    //===------------------------------------------------------------------===//
    // Other interface methods.
    //===------------------------------------------------------------------===//
    InterfaceMethod<
      /*desc=*/[{
        Return whether the op has only MemRef input and outputs.
      }],
      /*desc=*/"Return whether the op has only MemRef input and outputs.",
      /*retTy=*/"bool",
      /*methodName=*/"hasBufferSemantics",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return this->getOperation()->getNumResults() == 0 &&
               llvm::all_of(this->getOperation()->getOpOperands(),
        return $_op->getNumResults() == 0 &&
               llvm::all_of($_op->getOpOperands(),
                 [&](OpOperand &opOperand) {
                   return isScalar(&opOperand) ||
                          opOperand.get().getType().template isa<MemRefType>();
@@ -1299,15 +1102,13 @@ def DestinationStyleOpInterface : OpInterface<"DestinationStyleOpInterface"> {
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return whether the op has only RankedTensor input and outputs.
      }],
      /*desc=*/"Return whether the op has only RankedTensor input and outputs.",
      /*retTy=*/"bool",
      /*methodName=*/"hasTensorSemantics",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return llvm::all_of(this->getOperation()->getOpOperands(),
        return llvm::all_of($_op->getOpOperands(),
          [&](OpOperand &opOperand) {
            return isScalar(&opOperand) ||
                   opOperand.get().getType().template isa<RankedTensorType>();
@@ -215,6 +215,10 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
    getRegionBuilder() {
      return nullptr;
    }
    std::pair<int64_t, int64_t> getOutputsPositionRange() {
      int64_t getNumOperands = this->getNumOperands();
      return {getNumOperands - getOutputs().size(), getNumOperands};
    }
  }];

  let hasCanonicalizer = 1;
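// [Editor's note] Worked example (illustrative numbers, not from the patch):
// a linalg.generic with two inputs and one init keeps its outputs at the back
// of the operand list, so getNumOperands() == 3, getOutputs().size() == 1,
// and getOutputsPositionRange() returns {2, 3}; the single output is
// operand #2 and is tied to result #0.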
@@ -271,11 +275,10 @@ def MapOp : LinalgStructuredBase_Op<"map", [
    }

    // Implement functions necessary for DestinationStyleOpInterface.
    unsigned getNumInputs() {
      return this->getOperation()->getNumOperands() - getNumOutputs();
    };
    unsigned getNumOutputs() { return 1; };
    mlir::ValueRange getOutputs() { return getOperands().take_back(1); }
    std::pair<int64_t, int64_t> getOutputsPositionRange() {
      int64_t getNumOperands = this->getNumOperands();
      return {getNumOperands - 1, getNumOperands};
    }
    linalg::OpOperandVector getOpOperandsMatchingBBargs() {
      return getInputOperands();
    }
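// [Editor's note] For linalg.map the single init is always the trailing
// operand, so with k inputs the range is {k, k+1}; getOpOperandsMatchingBBargs
// returns only the inputs because the init carries no matching block argument.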
@@ -341,14 +344,14 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
    }

    // Implement functions necessary for DestinationStyleOpInterface.
    mlir::ValueRange getOutputs() { return getInits(); }
    unsigned getNumInputs() { return getInputs().size(); };
    unsigned getNumOutputs() { return getInits().size(); };
    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
                              mlir::ArrayRef<mlir::NamedAttribute>)>
    getRegionBuilder() {
      return nullptr;
    }
    std::pair<int64_t, int64_t> getOutputsPositionRange() {
      return {getInits().size(), getNumOperands()};
    }
  }];

  let hasCustomAssemblyFormat = 1;
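// [Editor's note] Returning {getInits().size(), getNumOperands()} here is
// equivalent to {numOperands - numInits, numOperands} only under the
// assumption (which linalg.reduce enforces) that there is one init per input:
// with n inputs and n inits, numOperands == 2 * n and the inits occupy
// positions [n, 2n).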
@@ -29,9 +29,9 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {

  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType()));
    argLocs.push_back(opOperand->get().getLoc());
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    argTypes.push_back(getElementTypeOrSelf(opOperand.get().getType()));
    argLocs.push_back(opOperand.get().getLoc());
  }

  ImplicitLocOpBuilder b(op->getLoc(), op->getContext());
@@ -166,6 +166,8 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
                          << " and " << *dst.getOperation() << "\n");
  if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
    for (OpOperand *dstOpOperand : dst.getInputOperands()) {
      if (!dstOpOperand->get().getType().isa<RankedTensorType>())
        continue;
      // Check if the operand is defined by the src.
      auto definingOp = dstOpOperand->get().getDefiningOp<LinalgOp>();
      if (definingOp && definingOp == src)
@@ -188,23 +190,31 @@ void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
  }
  assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
         "unhandled dependence tracking for mixed buffer/tensor operations");
  for (OpOperand *srcOpOperand : src.getOutputBufferOperands()) { // W
  for (OpOperand *srcOpOperand : src.getOutputOperands()) { // W
    // RAW graph
    for (OpOperand *dstOpOperand : dst.getInputBufferOperands())   // R
    for (OpOperand *dstOpOperand : dst.getInputOperands()) {   // R
      if (!dstOpOperand->get().getType().isa<MemRefType>())
        continue;
      if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAW alias
        addDependenceElem(DependenceType::RAW, srcOpOperand, dstOpOperand);
    }
    // WAW graph
    for (OpOperand *dstOpOperand : dst.getOutputBufferOperands())  // W
    for (OpOperand *dstOpOperand : dst.getOutputOperands())  // W
      if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAW alias
        addDependenceElem(DependenceType::WAW, srcOpOperand, dstOpOperand);
  }
  for (OpOperand *srcOpOperand : src.getInputBufferOperands()) { // R
  for (OpOperand *srcOpOperand : src.getInputOperands()) { // R
    if (!srcOpOperand->get().getType().isa<MemRefType>())
      continue;
    // RAR graph
    for (OpOperand *dstOpOperand : dst.getInputBufferOperands())   // R
    for (OpOperand *dstOpOperand : dst.getInputOperands()) {   // R
      if (!dstOpOperand->get().getType().isa<MemRefType>())
        continue;
      if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // RAR alias
        addDependenceElem(DependenceType::RAR, srcOpOperand, dstOpOperand);
    }
    // WAR graph
    for (OpOperand *dstOpOperand : dst.getOutputBufferOperands())  // W
    for (OpOperand *dstOpOperand : dst.getOutputOperands())  // W
      if (aliases.alias(srcOpOperand->get(), dstOpOperand->get())) // WAR alias
        addDependenceElem(DependenceType::WAR, srcOpOperand, dstOpOperand);
  }
@@ -31,10 +31,10 @@ using namespace mlir::linalg;
bool linalg::detail::canOpOperandsBeDroppedImpl(
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  SmallVector<AffineMap> indexingMaps;
  for (auto *opOperand : linalgOp.getInputAndOutputOperands()) {
    if (llvm::is_contained(droppedOperands, opOperand))
  for (auto &opOperand : linalgOp->getOpOperands()) {
    if (llvm::is_contained(droppedOperands, &opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(opOperand));
    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
  }
  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
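// [Editor's note] Hypothetical caller sketch (not part of the patch): an
// operand may only be dropped when the remaining indexing maps still let the
// loop ranges be recovered, i.e. the concatenated maps remain invertible.
//
//   SmallVector<OpOperand *> candidates = {linalgOp.getInputOperand(0)};
//   if (linalgOp.canOpOperandsBeDropped(candidates)) {
//     // Safe to rewrite the op without that operand.
//   }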
@@ -491,9 +491,9 @@ static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                Location loc) {
  SmallVector<OpFoldResult> res;
  for (OpOperand *opOperand : getInputAndOutputOperands()) {
    for (int64_t i = 0, e = getRank(opOperand); i < e; ++i)
      res.push_back(createFoldedDimOp(b, loc, opOperand->get(), i));
  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
    for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
      res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
  }
  return res;
}
@@ -501,8 +501,8 @@ SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand *opOperand : getInputAndOutputOperands())
    llvm::append_range(res, getShape(opOperand));
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    llvm::append_range(res, getShape(&opOperand));
  return res;
}
@@ -644,32 +644,32 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp.getNumInputsAndOutputs())
      linalgOp->getNumOperands())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp.getNumInputsAndOutputs() << ")";
           << linalgOp->getNumOperands() << ")";

  for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand->getOperandNumber();
             << opOperand.getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand->getOperandNumber() << " to have " << numLoops
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(opOperand);
    int64_t rank = linalgOp.getRank(&opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand->getOperandNumber() << " ("
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

@@ -688,13 +688,13 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(opOperand);
      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimension or the case that the dimension size is 0
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
@@ -725,17 +725,16 @@ LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
        if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim << " to be "
                   << inferredDimSize << ", but found " << shape[dim];
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be " << inferredDimSize << ", but found "
                   << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand->getOperandNumber()
                   << " has shape's dimension #" << dim
                   << " to be greater than or equal to " << inferredDimSize
                   << ", but found " << shape[dim];
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be greater than or equal to "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        }
      }
@@ -777,6 +776,15 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
  DestinationStyleOpInterface dstStyleOp =
      cast<DestinationStyleOpInterface>(op);

  SmallVector<OpOperand *> outputBufferOperands, outputTensorOperands;
  for (OpOperand *operand : dstStyleOp.getOutputOperands()) {
    Type type = operand->get().getType();
    if (type.isa<MemRefType>())
      outputBufferOperands.push_back(operand);
    if (type.isa<RankedTensorType>())
      outputTensorOperands.push_back(operand);
  }

  // Expect at least one output operand.
  // This means an op that constructs a tensor out of indices cannot be a
  // LinalgOp at the moment. For now this will have to be a special op until we
@@ -788,23 +796,22 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
  if (failed(OpTrait::impl::verifyNOperands(op, numInputs + numOutputs)))
    return failure();
  // Verify the number of results matches the number of output tensors.
  if (op->getNumResults() != dstStyleOp.getOutputTensorOperands().size())
  if (op->getNumResults() != outputTensorOperands.size())
    return op->emitOpError("expected the number of results (")
           << op->getNumResults()
           << ") to be equal to the number of output tensors ("
           << dstStyleOp.getOutputTensorOperands().size() << ")";
           << outputTensorOperands.size() << ")";

  // Simplifying assumption: either full tensor or full buffer mode.
  // This allows simpler verification of output operands vs result types
  // without premature tracking of which operand is what in mixed-mode.
  // TODO: relax when mixed-mode needs to pass verification.
  if (!dstStyleOp.getOutputBufferOperands().empty() &&
      !dstStyleOp.getOutputTensorOperands().empty())
  if (!outputBufferOperands.empty() && !outputTensorOperands.empty())
    return op->emitOpError(
        "expected output operands to all have tensor type or "
        "all have buffer type");

  for (OpOperand *opOperand : dstStyleOp.getOutputTensorOperands()) {
  for (OpOperand *opOperand : outputTensorOperands) {
    OpResult result = dstStyleOp.getTiedOpResult(opOperand);
    if (result.getType() != opOperand->get().getType())
      return op->emitOpError("expected type of operand #")
@@ -813,6 +820,5 @@ mlir::linalg::detail::verifyDestinationStyleOpInterface(Operation *op) {
             << " to match type of corresponding result (" << result.getType()
             << ")";
  }

  return success();
}
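// [Editor's note] Illustrative IR for the invariant being verified (not taken
// from the patch): one tensor init implies exactly one result whose type must
// match the tied init.
//
//   %r = linalg.generic { ... } ins(%a : tensor<4xf32>)
//                              outs(%init : tensor<4xf32>) { ... } -> tensor<4xf32>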
@@ -767,7 +767,8 @@ void GenericOp::print(OpAsmPrinter &p) {
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, getInputs(), getOutputs());
  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
                               SmallVector<Value>(getOutputOperands()));

  genericAttrNames.push_back("operand_segment_sizes");
  genericAttrNamesSet.insert(genericAttrNames.back());
@@ -835,15 +836,20 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
  for (Value value : inputBuffers) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
    ValueRange results, OpOperandVector inputOperands,
    OpOperandVector outputOperands) {
  for (auto *operand : inputOperands) {
    if (!operand->get().getType().isa<MemRefType>())
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
                         SideEffects::DefaultResource::get());
  }
  for (Value value : outputs) {
    effects.emplace_back(MemoryEffects::Read::get(), value,
  for (auto *operand : outputOperands) {
    if (!operand->get().getType().isa<MemRefType>())
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand->get(),
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), value,
    effects.emplace_back(MemoryEffects::Write::get(), operand->get(),
                         SideEffects::DefaultResource::get());
  }
}
@@ -851,10 +857,8 @@ static void getGenericEffectsImpl(
void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> inputBuffers = getInputBufferOperands();
  SmallVector<Value> outputBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
                        outputBuffers);
  getGenericEffectsImpl(effects, getOperation()->getResults(),
                        getInputOperands(), getOutputOperands());
}

static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
@@ -925,7 +929,7 @@ struct DeduplicateAndRemoveDeadOperandsAndResults

    // Check if there is any change to operands.
    if (newInputOperands.size() + newOutputOperands.size() ==
        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
        genericOp->getNumOperands())
      return failure();

    // Create the new op with the body being empty.
@@ -977,35 +981,34 @@ private:
      SmallVector<AffineMap> &newIndexingMaps) const {
    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
    for (const auto &inputOpOperand :
         llvm::enumerate(genericOp.getInputOperands())) {
    for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) {
      OpOperand *inputOpOperand = en.value();
      // Check if the operand is dead and if dropping its indexing map would
      // make the loop-to-shape computation invalid.
      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
        // Add the current operands to the list of potentially droppable
        // operands. If it cannot be dropped, this needs to be popped back.
        droppedOpOperands.push_back(inputOpOperand.value());
        droppedOpOperands.push_back(inputOpOperand);
        if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
          continue;
        droppedOpOperands.pop_back();
      }

      // Check if this operand is a duplicate.
      AffineMap indexingMap =
          genericOp.getMatchingIndexingMap(inputOpOperand.value());
      AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand);
      auto it = dedupedInputs.find(
          std::make_pair(inputOpOperand.value()->get(), indexingMap));
          std::make_pair(inputOpOperand->get(), indexingMap));
      if (it != dedupedInputs.end()) {
        origToNewPos[inputOpOperand.index()] = it->second;
        droppedOpOperands.push_back(inputOpOperand.value());
        origToNewPos[en.index()] = it->second;
        droppedOpOperands.push_back(inputOpOperand);
        continue;
      }

      // This is a preserved argument.
      origToNewPos[inputOpOperand.index()] = newInputOperands.size();
      dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
      origToNewPos[en.index()] = newInputOperands.size();
      dedupedInputs[{inputOpOperand->get(), indexingMap}] =
          newInputOperands.size();
      newInputOperands.push_back(inputOpOperand.value()->get());
      newInputOperands.push_back(inputOpOperand->get());
      newIndexingMaps.push_back(indexingMap);
    }
    return origToNewPos;
@@ -1026,12 +1029,10 @@ private:
    // If the op doesn't have tensor semantics, keep all the outputs as
    // preserved.
    if (!genericOp.hasTensorSemantics()) {
      for (const auto &outputOpOperand :
           llvm::enumerate(genericOp.getOutputOperands())) {
        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
        newOutputOperands.push_back(outputOpOperand.value()->get());
        newIndexingMaps.push_back(
            genericOp.getMatchingIndexingMap(outputOpOperand.value()));
      for (const auto &en : llvm::enumerate(genericOp.getOutputOperands())) {
        origToNewPos[en.index()] = newOutputOperands.size();
        newOutputOperands.push_back(en.value()->get());
        newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value()));
      }
      return origToNewPos;
    }
@@ -1347,7 +1348,8 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
}

void MapOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getInputs(), getOutputs());
  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
                               SmallVector<Value>(getOutputOperands()));
  p.printOptionalAttrDict((*this)->getAttrs());

  p << "(";
@@ -1380,7 +1382,7 @@ LogicalResult MapOp::verify() {

  // The shape of each input must match the shape of the output.
  auto outputShape =
      getOutputs().front().getType().cast<ShapedType>().getShape();
      getOutputOperand(0)->get().getType().cast<ShapedType>().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
    if (inputElemShape != outputShape) {
@@ -1409,10 +1411,8 @@ ArrayAttr MapOp::getIndexingMaps() {
void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> inputBuffers = getInputBufferOperands();
  SmallVector<Value> outputBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
                        outputBuffers);
  getGenericEffectsImpl(effects, getOperation()->getResults(),
                        getInputOperands(), getOutputOperands());
}

//===----------------------------------------------------------------------===//
@@ -1458,10 +1458,8 @@ ArrayAttr ReduceOp::getIndexingMaps() {
void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  SmallVector<Value> inputBuffers = getInputBufferOperands();
  SmallVector<Value> outputBuffers = getOutputBufferOperands();
  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
                        outputBuffers);
  getGenericEffectsImpl(effects, getOperation()->getResults(),
                        getInputOperands(), getOutputOperands());
}

static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1500,7 +1498,8 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
}

void ReduceOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getInputs(), getOutputs());
  printCommonStructuredOpParts(p, SmallVector<Value>(getInputOperands()),
                               SmallVector<Value>(getOutputOperands()));
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});

@@ -1584,10 +1583,11 @@ LogicalResult ReduceOp::verify() {
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getOutputs(), block->getArguments().take_back(getNumOutputs()))) {
  for (auto [output, bbArg] :
       llvm::zip(getOutputOperands(),
                 block->getArguments().take_back(getNumOutputs()))) {
    auto outputElementType =
        output.getType().cast<ShapedType>().getElementType();
        output->get().getType().cast<ShapedType>().getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
@@ -1751,14 +1751,14 @@ struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase in cases we see memref<...x0x...>.
      auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
      auto mt = opOperand.get().getType().dyn_cast<MemRefType>();
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(opOperand), 0)) {
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
@@ -1774,10 +1774,10 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
                                PatternRewriter &rewriter) const override {
    // If no operand comes from a tensor::CastOp and can be folded then fail.
    bool hasTensorCastOperand =
        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
          if (opOperand->get().isa<BlockArgument>())
        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
          if (opOperand.get().isa<BlockArgument>())
            return false;
          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
        });
    if (!hasTensorCastOperand)
@@ -1788,18 +1788,17 @@ struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(op->getNumOperands());
    // Inputs may fold.
    for (OpOperand *opOperand : op.getInputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
    for (auto *input : op.getInputOperands()) {
      auto tensorCastOp = input->get().getDefiningOp<tensor::CastOp>();
      newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
                                ? tensorCastOp.getSource()
                                : opOperand->get());
                                : input->get());
    }
    // Init tensors may fold, in which case the resultType must also change.
    for (OpOperand *opOperand : op.getOutputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
    for (auto *output : op.getOutputOperands()) {
      auto tensorCastOp = output->get().getDefiningOp<tensor::CastOp>();
      bool fold = canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand()
                                 : opOperand->get());
      newOperands.push_back(fold ? tensorCastOp.getOperand() : output->get());
      newResultTypes.push_back(newOperands.back().getType());
    }
    // Clone op.
@@ -1858,8 +1857,8 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
    Value newOperand =
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getInputOperands();
    SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
    SmallVector<Value> newOperands{linalgOp.getInputOperands()};
    SmallVector<Value> outputOperands{linalgOp.getOutputOperands()};
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());
@@ -1882,14 +1881,14 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {

/// For each of the operands, this function maps the static sizes of
/// dimensions to their affine dim expressions.
static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand *opOperand : operands) {
    if (linalgOp.isScalar(opOperand))
  for (OpOperand &opOperand : operands) {
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand->get();
    Value src = opOperand.get();
    auto sourceType = src.getType().cast<RankedTensorType>();
    auto sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // `tensor.cast` operation and source of the cast operation has a static
@@ -1932,7 +1931,7 @@ static void createNewOperandWithStaticSizes(
    return;
  auto sourceType = src.getType().cast<RankedTensorType>();
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
  if (sourceType.hasStaticShape() && linalgOp.isOutput(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
@@ -1965,7 +1964,7 @@ static void createNewOperandWithStaticSizes(
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isOutputTensor(opOperand))
  if (linalgOp.isOutput(opOperand))
    resultTypes.push_back(resultType);
}
@@ -1992,8 +1991,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {

    // For each of the affine dim expressions, check if the size is known. If
    // known, add it to the map.
    populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
                affineExprToSize);
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;
@@ -2001,12 +1999,12 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
    // `changeNeeded` is `false` if the operands of `linalgOp` require no
    // change in their types.
    bool changeNeeded = false;
    newOperands.reserve(linalgOp.getNumInputsAndOutputs());
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumOutputs());

    // Iterate over all the operands and update the static sizes.
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }
@@ -112,14 +112,14 @@ struct BubbleUpExtractSliceOpPattern
      tileSizes[position] = sliceOp.getMixedSizes()[result.index()];
    }

    SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
    SmallVector<Value> valuesToTile = linalgOp->getOperands();
    SmallVector<Value> tiledOperands =
        makeTiledShapes(rewriter, linalgLoc, linalgOp, valuesToTile,
                        tileOffsets, tileSizes, sizeBounds,
                        /*omitPartialTileCheck=*/true);

    SmallVector<Type, 4> resultTensorTypes;
    for (OpOperand *opOperand : linalgOp.getOutputTensorOperands())
    for (OpOperand *opOperand : linalgOp.getOutputOperands())
      resultTensorTypes.push_back(
          tiledOperands[opOperand->getOperandNumber()].getType());

@@ -118,7 +118,7 @@ struct LinalgOpInterface
    auto genericOp = cast<linalg::DestinationStyleOpInterface>(op);

    // The i-th "out" tensor may alias with the i-th OpResult.
    if (genericOp.isOutputTensor(&opOperand))
    if (genericOp.isOutput(&opOperand))
      return {genericOp.getTiedOpResult(&opOperand)};
    return {};
  }
@@ -68,17 +68,17 @@ public:
if (!outputType || !outputType.hasStaticShape())
return failure();

if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
return operand->get().getType().isa<ShapedType>();
if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
return input.getType().isa<ShapedType>();
}))
return failure();

// Make sure all element types are the same.
auto getOperandElementType = [](OpOperand *operand) {
return operand->get().getType().cast<ShapedType>().getElementType();
auto getOperandElementType = [](Value value) {
return value.getType().cast<ShapedType>().getElementType();
};
if (!llvm::all_equal(llvm::map_range(genericOp.getInputAndOutputOperands(),
getOperandElementType)))
if (!llvm::all_equal(
llvm::map_range(genericOp->getOperands(), getOperandElementType)))
return failure();

// We can only handle the case where we have int/float elements.
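Where a check only needs the SSA values rather than the owning operand slots, the patch switches to the value-based accessors. A small sketch of the two equivalent forms, assuming genericOp is a linalg.generic handle (illustrative):

// Operand-based form: go through OpOperand and dereference.
bool allShaped = llvm::all_of(
    genericOp.getInputOperands(),
    [](OpOperand *operand) { return operand->get().getType().isa<ShapedType>(); });
// Value-based form: getInputs() already yields the input Values.
bool allShapedToo = llvm::all_of(
    genericOp.getInputs(),
    [](Value input) { return input.getType().isa<ShapedType>(); });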
@@ -114,15 +114,15 @@ public:
|
||||
// All inputs should be constants.
|
||||
int numInputs = genericOp.getNumInputs();
|
||||
SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
|
||||
for (const auto &operand : llvm::enumerate(genericOp.getInputOperands())) {
|
||||
if (!matchPattern(operand.value()->get(),
|
||||
m_Constant(&inputValues[operand.index()])))
|
||||
for (const auto &en : llvm::enumerate(genericOp.getInputOperands())) {
|
||||
if (!matchPattern(en.value()->get(),
|
||||
m_Constant(&inputValues[en.index()])))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Identified this as a potential candidate for folding. Now check the
|
||||
// policy to see whether we are allowed to proceed.
|
||||
for (auto *operand : genericOp.getInputOperands()) {
|
||||
for (OpOperand *operand : genericOp.getInputOperands()) {
|
||||
if (!controlFn(operand))
|
||||
return failure();
|
||||
}
|
||||
@@ -171,8 +171,8 @@ public:
|
||||
APIntOrFloatArray computeFnInputs;
|
||||
|
||||
auto inputShapes = llvm::to_vector<4>(
|
||||
llvm::map_range(genericOp.getInputOperands(), [](OpOperand *operand) {
|
||||
return operand->get().getType().cast<ShapedType>().getShape();
|
||||
llvm::map_range(genericOp.getInputs(), [](Value value) {
|
||||
return value.getType().cast<ShapedType>().getShape();
|
||||
}));
|
||||
|
||||
// Given a `linearIndex`, remap it to a linear index to access linalg op
|
||||
|
||||
@@ -194,7 +194,7 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp,
|
||||
}
|
||||
|
||||
/// Create the peeled generic op with an empty body.
|
||||
SmallVector<Value> outsOperands = genericOp.getOutputOperands();
|
||||
SmallVector<Value> outsOperands = genericOp.getOutputs();
|
||||
outsOperands.append(newInitValues.begin(), newInitValues.end());
|
||||
SmallVector<Type> resultTypes = llvm::to_vector(genericOp.getResultTypes());
|
||||
resultTypes.append(newResultTypes.begin(), newResultTypes.end());
|
||||
@@ -212,9 +212,7 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const {
|
||||
/// Append all results from the peeledGenericOps as `ins` operand for the
|
||||
/// residual generic op.
|
||||
SmallVector<Value> residualGenericOpOperands = llvm::to_vector(
|
||||
llvm::map_range(genericOp.getInputOperands(),
|
||||
[](OpOperand *operand) { return operand->get(); }));
|
||||
SmallVector<Value> residualGenericOpOperands = genericOp.getInputs();
|
||||
unsigned origNumResults = genericOp.getNumResults();
|
||||
unsigned peeledGenericOpNumResults = peeledGenericOp.getNumResults();
|
||||
SmallVector<Value> extraIns;
|
||||
|
||||
@@ -55,10 +55,9 @@ bool canBeDetensored(TensorType tensorType) {
bool shouldBeDetensored(Operation *op, TypeConverter typeConverter) {
GenericOp genericOp = dyn_cast_or_null<GenericOp>(op);
return genericOp &&
llvm::all_of(
genericOp.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
return !typeConverter.isLegal(opOperand->get().getType());
});
llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
return !typeConverter.isLegal(opOperand.get().getType());
});
}

/// A conversion patttern for detensoring `linalg.generic` ops.

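Because Operation::getOpOperands() hands out references rather than pointers, callbacks change signature from OpOperand * to OpOperand & when they move over. A hedged sketch of the resulting idiom, with typeConverter and genericOp assumed in scope (illustrative):

// Detensoring is only interesting if no operand type is already legal.
bool allIllegal =
    llvm::all_of(genericOp->getOpOperands(), [&](OpOperand &opOperand) {
      return !typeConverter.isLegal(opOperand.get().getType());
    });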
@@ -377,21 +377,21 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
|
||||
SmallVector<ArrayAttr> reassociationMaps;
|
||||
SmallVector<Type> newInputOutputTypes;
|
||||
bool doCanonicalization = false;
|
||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
||||
auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
|
||||
for (OpOperand &opOperand : genericOp->getOpOperands()) {
|
||||
auto replacementInfo = replaceUnitExtents(genericOp, &opOperand, context);
|
||||
if (replacementInfo) {
|
||||
reassociationMaps.push_back(replacementInfo->reassociation);
|
||||
newIndexingMaps.push_back(replacementInfo->indexMap);
|
||||
newInputOutputTypes.push_back(replacementInfo->type);
|
||||
doCanonicalization |=
|
||||
replacementInfo->type != opOperand->get().getType();
|
||||
replacementInfo->type != opOperand.get().getType();
|
||||
} else {
|
||||
// If replaceUnitExtents cannot handle this case, maintain the same
|
||||
// type, indexing map, and create a set of mappings representing an
|
||||
// identity matrix.
|
||||
newInputOutputTypes.push_back(opOperand->get().getType());
|
||||
newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(opOperand));
|
||||
int64_t origRank = genericOp.getRank(opOperand);
|
||||
newInputOutputTypes.push_back(opOperand.get().getType());
|
||||
newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(&opOperand));
|
||||
int64_t origRank = genericOp.getRank(&opOperand);
|
||||
auto maps = llvm::to_vector<8>(llvm::map_range(
|
||||
llvm::seq<int64_t>(0, origRank), [&](int64_t dim) -> Attribute {
|
||||
return AffineMapAttr::get(
|
||||
|
||||
@@ -90,7 +90,7 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
|
||||
|
||||
// Only allow fusing the producer of an input operand for now.
|
||||
// TODO: allow fusing the producer of an output operand.
|
||||
if (!consumer.isInputTensor(fusedOperand))
|
||||
if (!consumer.isInput(fusedOperand))
|
||||
return false;
|
||||
|
||||
// Get the consumer index map. The number of results of the consumer index
|
||||
@@ -179,7 +179,7 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
|
||||
}
|
||||
}
|
||||
// TODO: allow fusing the producer of an output operand.
|
||||
assert(consumer.isInputTensor(fusedOperand) &&
|
||||
assert(consumer.isInput(fusedOperand) &&
|
||||
"expected producer of input operand");
|
||||
// 3. Consumer input operands up to consumerIdx (exclusive).
|
||||
for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
|
||||
@@ -267,7 +267,7 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
|
||||
auto producer = cast<GenericOp>(producerResult.getOwner());
|
||||
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
|
||||
// TODO: allow fusing the producer of an output operand.
|
||||
assert(consumer.isInputTensor(fusedOperand) &&
|
||||
assert(consumer.isInput(fusedOperand) &&
|
||||
"expected producer of input operand");
|
||||
|
||||
// Compute the fused operands list and indexing maps.
|
||||
@@ -278,13 +278,14 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
fusedOutputOperands.reserve(producer.getNumOutputs() +
consumer.getNumOutputs());
fusedResultTypes.reserve(producer.getNumOutputs() + consumer.getNumOutputs());
fusedIndexMaps.reserve(producer.getNumInputsAndOutputs() +
consumer.getNumInputsAndOutputs());
fusedIndexMaps.reserve(producer->getNumOperands() +
consumer->getNumOperands());
// In the following, numbering matches that of `generateFusedTensorOpRegion`.
// 3. Consumer input operands/maps up to consumerIdx (exclusive).
SmallVector<OpOperand *> consumerInputs = consumer.getInputOperands();
SmallVector<OpOperand *>::iterator it =
llvm::find(consumerInputs, fusedOperand);
auto consumerInputs = consumer.getInputOperands();
auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
return operand == fusedOperand;
});
assert(it != consumerInputs.end() && "expected to find the consumer operand");
for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
fusedInputOperands.push_back(opOperand->get());
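Operand counts follow the same pattern: sizes that used to come from the interface (getNumInputsAndOutputs) now come straight from the operation. A minimal sketch, with producer and consumer the two ops being fused and fusedIndexMaps an illustrative name:

// One indexing map per operand of either op, so reserve using the
// operations' own operand counts.
SmallVector<AffineMap> fusedIndexMaps;
fusedIndexMaps.reserve(producer->getNumOperands() +
                       consumer->getNumOperands());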
@@ -373,13 +374,13 @@ public:
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Find the first operand that is defined by another generic op on tensors.
|
||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
||||
if (!areElementwiseOpsFusable(opOperand))
|
||||
for (OpOperand &opOperand : genericOp->getOpOperands()) {
|
||||
if (!areElementwiseOpsFusable(&opOperand))
|
||||
continue;
|
||||
if (!controlFn(opOperand))
|
||||
if (!controlFn(&opOperand))
|
||||
continue;
|
||||
|
||||
FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, opOperand);
|
||||
FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, &opOperand);
|
||||
if (succeeded(fusedOp)) {
|
||||
auto replacements =
|
||||
fusedOp.value()->getResults().take_back(genericOp.getNumResults());
|
||||
@@ -727,9 +728,9 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
|
||||
: collapsingReshapeOp.getSrc());
|
||||
continue;
|
||||
}
|
||||
if (genericOp.isInputTensor(opOperand)) {
|
||||
if (auto opOperandType =
|
||||
opOperand->get().getType().dyn_cast<RankedTensorType>()) {
|
||||
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
|
||||
auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
|
||||
RankedTensorType expandedOperandType =
|
||||
getExpandedType(opOperandType, indexingMap, expansionInfo);
|
||||
if (expandedOperandType != opOperand->get().getType()) {
|
||||
@@ -833,7 +834,7 @@ public:
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
|
||||
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
||||
tensor::CollapseShapeOp reshapeOp =
|
||||
opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
|
||||
if (!reshapeOp)
|
||||
@@ -1494,17 +1495,17 @@ public:
|
||||
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
|
||||
for (OpOperand &opOperand : genericOp->getOpOperands()) {
|
||||
tensor::ExpandShapeOp reshapeOp =
|
||||
opOperand->get().getDefiningOp<tensor::ExpandShapeOp>();
|
||||
opOperand.get().getDefiningOp<tensor::ExpandShapeOp>();
|
||||
if (!reshapeOp)
|
||||
continue;
|
||||
|
||||
SmallVector<ReassociationIndices> collapsableIterationDims =
|
||||
getCollapsableIterationSpaceDims(genericOp, opOperand,
|
||||
getCollapsableIterationSpaceDims(genericOp, &opOperand,
|
||||
reshapeOp.getReassociationIndices());
|
||||
if (collapsableIterationDims.empty() ||
|
||||
!controlFoldingReshapes(opOperand)) {
|
||||
!controlFoldingReshapes(&opOperand)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -1614,7 +1615,7 @@ public:
|
||||
SmallVector<AffineMap> fusedIndexMaps;
|
||||
SmallVector<Value> fusedOperands;
|
||||
SmallVector<Location> fusedLocs{genericOp.getLoc()};
|
||||
fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs());
|
||||
fusedIndexMaps.reserve(genericOp->getNumOperands());
|
||||
fusedOperands.reserve(genericOp.getNumInputs());
|
||||
fusedLocs.reserve(fusedLocs.size() + genericOp.getNumInputs());
|
||||
for (OpOperand *inputOperand : genericOp.getInputOperands()) {
|
||||
@@ -1640,7 +1641,7 @@ public:
Value scalarConstant = rewriter.create<arith::ConstantOp>(
def->getLoc(), constantAttr, constantAttr.getType());

SmallVector<Value> outputOperands = genericOp.getOutputOperands();
SmallVector<Value> outputOperands = genericOp.getOutputs();
auto fusedOp = rewriter.create<GenericOp>(
rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
/*inputs=*/fusedOperands,

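The switch from getOutputOperands() to getOutputs() removes one level of indirection: the former returns OpOperand * entries, the latter the output Values themselves. A short sketch of when each is appropriate, assuming genericOp is in scope (illustrative only):

// Value-based: collect the "outs" values to feed a builder.
SmallVector<Value> outs = genericOp.getOutputs();
// OpOperand-based: still used when the operand slot itself matters,
// e.g. to ask for its position in the operand list.
for (OpOperand *opOperand : genericOp.getOutputOperands())
  (void)opOperand->getOperandNumber();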
@@ -68,7 +68,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
|
||||
bool fromSubViewOpOnly = false) {
|
||||
// Iterate over the inputs and outputs in order.
|
||||
// Extract the subranges from the linearized ranges.
|
||||
for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
|
||||
for (OpOperand &opOperand : op->getOpOperands()) {
|
||||
// The method `getRangeFromOperandShape` requires using SubViewOp or
|
||||
// ExtractSliceOps. If the value isn't defined from there continue.
|
||||
// todo: The method should be adapted to get the values from
|
||||
@@ -77,12 +77,12 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
|
||||
// `std` dialect and add the method to `ViewInterface`.
|
||||
if (fromSubViewOpOnly &&
|
||||
!isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
|
||||
opOperand->get().getDefiningOp()))
|
||||
opOperand.get().getDefiningOp()))
|
||||
continue;
|
||||
|
||||
AffineMap map = op.getMatchingIndexingMap(opOperand);
|
||||
AffineMap map = op.getMatchingIndexingMap(&opOperand);
|
||||
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
|
||||
<< opOperand->getOperandNumber() << "\n");
|
||||
<< opOperand.getOperandNumber() << "\n");
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< "getShapeDefiningLoopRange map: " << map << "\n");
|
||||
SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
|
||||
@@ -94,8 +94,8 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
|
||||
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
|
||||
<< loopDepth << "\n");
|
||||
LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
|
||||
<< opOperand->get() << "\n");
|
||||
return ShapeDimension{opOperand->get(),
|
||||
<< opOperand.get() << "\n");
|
||||
return ShapeDimension{opOperand.get(),
|
||||
static_cast<unsigned>(en.index())};
|
||||
}
|
||||
}
|
||||
@@ -104,7 +104,7 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
|
||||
}
|
||||
|
||||
static SmallVector<Value> getTiledOperands(LinalgOp producer) {
|
||||
return producer.getInputAndOutputOperands();
|
||||
return producer->getOperands();
|
||||
}
|
||||
|
||||
/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
|
||||
@@ -137,7 +137,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
|
||||
}
|
||||
|
||||
SmallVector<Value, 8> clonedShapes;
|
||||
clonedShapes.reserve(producer.getNumInputsAndOutputs());
|
||||
clonedShapes.reserve(producer->getNumOperands());
|
||||
|
||||
// Compute subranges for all tensor input/output operands.
|
||||
clonedShapes.append(makeTiledShapes(
|
||||
@@ -150,15 +150,18 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
// fully dynamic at construction time.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(producer->getNumResults());
for (RankedTensorType t : producer.getOutputTensorTypes()) {
unsigned rank = t.getRank();
for (OpOperand *operand : producer.getOutputOperands()) {
auto tensorType = operand->get().getType().dyn_cast<RankedTensorType>();
if (!tensorType)
continue;
unsigned rank = tensorType.getRank();
SmallVector<int64_t, 4> staticOffsetsVector(
rank, ShapedType::kDynamicStrideOrOffset);
SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
SmallVector<int64_t, 4> staticStridesVector(
rank, ShapedType::kDynamicStrideOrOffset);
resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
tensorType, staticOffsetsVector, staticSizesVector,
staticStridesVector));
}

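With the tensor-filtered accessors gone, code that needs the ranked tensor types of the outputs filters the output operands itself. A minimal sketch, assuming producer is the LinalgOp whose result types are being recomputed (illustrative):

SmallVector<RankedTensorType> outputTensorTypes;
for (OpOperand *operand : producer.getOutputOperands()) {
  // Buffer (memref) outputs produce no results; keep ranked tensors only.
  if (auto tensorType =
          operand->get().getType().dyn_cast<RankedTensorType>())
    outputTensorTypes.push_back(tensorType);
}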
@@ -161,7 +161,7 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
|
||||
allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
|
||||
}
|
||||
erase_value(tileIvs, OpFoldResult());
|
||||
SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
|
||||
SmallVector<Value> tiledOperands = producerOp->getOperands();
|
||||
tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
|
||||
tileSizes, producerLoopBounds,
|
||||
/**omitPartialTileCheck=*/false);
|
||||
|
||||
@@ -50,19 +50,19 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter,
if (failed(generalizeNamedOpPrecondition(linalgOp)))
return rewriter.notifyMatchFailure(linalgOp, "preconditions not met");

SmallVector<Value> inputOperands = linalgOp.getInputOperands();
SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
SmallVector<Value> inputs = linalgOp.getInputOperands();
SmallVector<Value> outputs = linalgOp.getOutputOperands();
SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
SmallVector<StringRef> iterators = linalgOp.getIteratorTypesArray();
SmallVector<RankedTensorType> resultTypes = linalgOp.getOutputTensorTypes();
SmallVector<Type> types(resultTypes.begin(), resultTypes.end());
SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
? TypeRange(ValueRange(outputs))
: TypeRange{};

// All named ops have a region attached that can be inlined.
assert(linalgOp->getNumRegions() == 1 &&
"expect named op to have one region attached");
GenericOp genericOp =
rewriter.create<GenericOp>(linalgOp.getLoc(), types, inputOperands,
outputOperands, indexingMaps, iterators);
GenericOp genericOp = rewriter.create<GenericOp>(
linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators);
rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(),
genericOp.getRegion().begin());
rewriter.replaceOp(linalgOp, genericOp->getResults());

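Generalization now derives the generic op's result types from its destination-style outputs: no results in the buffer case, and exactly the output types in the tensor case. A condensed restatement of that invariant, using the names from the hunk above (illustrative):

// Destination-passing style: each "outs" tensor yields one result type.
SmallVector<Type> resultTypes = linalgOp.hasTensorSemantics()
                                    ? TypeRange(ValueRange(outputs))
                                    : TypeRange{};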
@@ -111,7 +111,7 @@ private:
|
||||
static bool isOnlyUsedAsInputOfLinalgOp(tensor::PadOp padOp) {
|
||||
for (OpOperand &use : padOp.getResult().getUses()) {
|
||||
auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
|
||||
if (!linalgUser || !linalgUser.isInputTensor(&use)) {
|
||||
if (!linalgUser || !linalgUser.isInput(&use)) {
|
||||
LLVM_DEBUG(DBGS() << "Found a use of " << *(padOp)
|
||||
<< "\nthat is not an input tensor of a LinalgOp, "
|
||||
<< "cannot hoist\n"
|
||||
|
||||
@@ -43,7 +43,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
|
||||
SmallVector<Value> newOperands;
|
||||
for (OpOperand *opOperand : genericOp.getInputOperands()) {
|
||||
AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
|
||||
if (genericOp.isInputTensor(opOperand) && map.isConstant()) {
|
||||
if (genericOp.isInput(opOperand) && map.isConstant()) {
|
||||
scalarOperands.emplace_back(opOperand->getOperandNumber());
|
||||
} else {
|
||||
newIndexingMaps.emplace_back(map);
|
||||
@@ -58,7 +58,7 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
|
||||
newIndexingMaps.emplace_back(genericOp.getMatchingIndexingMap(opOperand));
|
||||
|
||||
Location loc = genericOp->getLoc();
|
||||
SmallVector<Value> outputOperands = genericOp.getOutputOperands();
|
||||
SmallVector<Value> outputOperands = genericOp.getOutputs();
|
||||
auto newOp = rewriter.create<GenericOp>(
|
||||
loc, genericOp->getResultTypes(), newOperands, outputOperands,
|
||||
newIndexingMaps, genericOp.getIteratorTypesArray());
|
||||
|
||||
@@ -67,8 +67,8 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp,
|
||||
|
||||
// 2. Compute the interchanged indexing maps.
|
||||
SmallVector<AffineMap> newIndexingMaps;
|
||||
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
|
||||
AffineMap m = genericOp.getMatchingIndexingMap(opOperand);
|
||||
for (OpOperand &opOperand : genericOp->getOpOperands()) {
|
||||
AffineMap m = genericOp.getMatchingIndexingMap(&opOperand);
|
||||
if (!permutationMap.isEmpty())
|
||||
m = m.compose(permutationMap);
|
||||
newIndexingMaps.push_back(m);
|
||||
|
||||
@@ -131,7 +131,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
|
||||
assert(linalgOp.hasBufferSemantics() &&
|
||||
"expected linalg op with buffer semantics");
|
||||
SmallVector<Value> indexedValues;
|
||||
indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
|
||||
indexedValues.reserve(linalgOp->getNumOperands());
|
||||
|
||||
auto allIvsPlusDims = SmallVector<Value>(allIvs.begin(), allIvs.end());
|
||||
|
||||
@@ -161,7 +161,9 @@ static void emitScalarImplementation(OpBuilder &b, Location loc,
|
||||
// 3. Emit store.
|
||||
SmallVector<SmallVector<Value>, 8> indexing;
|
||||
SmallVector<Value> outputBuffers;
|
||||
for (OpOperand *outputOperand : linalgOp.getOutputBufferOperands()) {
|
||||
for (OpOperand *outputOperand : linalgOp.getOutputOperands()) {
|
||||
if (!outputOperand->get().getType().isa<MemRefType>())
|
||||
continue;
|
||||
indexing.push_back(makeCanonicalAffineApplies(
|
||||
b, loc, linalgOp.getMatchingIndexingMap(outputOperand),
|
||||
allIvsPlusDims));
|
||||
|
||||
@@ -145,15 +145,15 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
|
||||
assert(linalgOp.hasBufferSemantics() && "revisit usage of shaped operand");
|
||||
auto vUseFullTileBuffers =
|
||||
options.useFullTileBuffers.value_or(llvm::SmallBitVector());
|
||||
vUseFullTileBuffers.resize(linalgOp.getNumInputsAndOutputs(),
|
||||
vUseFullTileBuffers.resize(linalgOp->getNumOperands(),
|
||||
options.useFullTileBuffersDefault);
|
||||
|
||||
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
|
||||
int64_t operandNumber = opOperand->getOperandNumber();
|
||||
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
|
||||
int64_t operandNumber = opOperand.getOperandNumber();
|
||||
if (options.operandsToPromote &&
|
||||
!options.operandsToPromote->count(operandNumber))
|
||||
continue;
|
||||
Operation *op = opOperand->get().getDefiningOp();
|
||||
Operation *op = opOperand.get().getDefiningOp();
|
||||
if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
|
||||
subViews[operandNumber] = sv;
|
||||
useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
|
||||
@@ -326,13 +326,13 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
|
||||
// operands are not views. This is to support cases such as FillOp taking
|
||||
// extra scalars etc. Keep a reference to output buffers;
|
||||
SmallVector<Value, 8> opViews;
|
||||
opViews.reserve(op.getNumInputsAndOutputs());
|
||||
opViews.reserve(op->getNumOperands());
|
||||
SmallVector<std::pair<Value, Value>, 8> writebackViews;
|
||||
writebackViews.reserve(promotedBuffersAndViews->size());
|
||||
for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
|
||||
int64_t operandNumber = opOperand->getOperandNumber();
|
||||
for (OpOperand &opOperand : op->getOpOperands()) {
|
||||
int64_t operandNumber = opOperand.getOperandNumber();
|
||||
if (options.subViews.count(operandNumber) != 0) {
|
||||
if (options.useFullTileBuffers[opOperand->get()])
|
||||
if (options.useFullTileBuffers[opOperand.get()])
|
||||
opViews.push_back(
|
||||
(*promotedBuffersAndViews)[operandNumber].fullLocalView);
|
||||
else
|
||||
@@ -340,10 +340,10 @@ promoteSubViews(ImplicitLocOpBuilder &b, LinalgOp op,
|
||||
(*promotedBuffersAndViews)[operandNumber].partialLocalView);
|
||||
if (operandNumber >= op.getNumInputs())
|
||||
writebackViews.emplace_back(std::make_pair(
|
||||
opOperand->get(),
|
||||
opOperand.get(),
|
||||
(*promotedBuffersAndViews)[operandNumber].partialLocalView));
|
||||
} else {
|
||||
opViews.push_back(opOperand->get());
|
||||
opViews.push_back(opOperand.get());
|
||||
}
|
||||
}
|
||||
op->setOperands(0, opViews.size(), opViews);
|
||||
@@ -371,12 +371,12 @@ mlir::linalg::promoteSubviewsPrecondition(Operation *op,
|
||||
if (!linalgOp || !linalgOp.hasBufferSemantics())
|
||||
return failure();
|
||||
// Check that at least one of the requested operands is indeed a subview.
|
||||
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
|
||||
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
|
||||
auto sv =
|
||||
isa_and_nonnull<memref::SubViewOp>(opOperand->get().getDefiningOp());
|
||||
isa_and_nonnull<memref::SubViewOp>(opOperand.get().getDefiningOp());
|
||||
if (sv) {
|
||||
if (!options.operandsToPromote ||
|
||||
options.operandsToPromote->count(opOperand->getOperandNumber()))
|
||||
options.operandsToPromote->count(opOperand.getOperandNumber()))
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,7 +214,6 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
|
||||
// from the previous op.
|
||||
unsigned intermRank = newOutputShape.size();
|
||||
AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
|
||||
SmallVector<Value> outputOperands = op.getOutputOperands();
|
||||
SmallVector<StringRef> reductionIteratorTypes;
|
||||
SmallVector<AffineExpr> exprs;
|
||||
for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
|
||||
@@ -230,7 +229,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
|
||||
|
||||
auto reduction = b.create<GenericOp>(
|
||||
loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
|
||||
outputOperands, reductionMaps, reductionIteratorTypes,
|
||||
SmallVector<Value>{op.getOutputOperands()}, reductionMaps,
|
||||
reductionIteratorTypes,
|
||||
[reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
|
||||
Operation *clonedReductionOp = b.clone(*reductionOp);
|
||||
clonedReductionOp->setOperand(0, inputs[0]);
|
||||
@@ -341,8 +341,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
|
||||
SmallVector<Operation *> emptyOrAllocTensorOps;
|
||||
SmallVector<linalg::FillOp> fillOps;
|
||||
fillOps.reserve(op.getNumOutputs());
|
||||
for (auto it : llvm::zip(op.getOutputs(), neutralElements)) {
|
||||
Value rankedTensor = std::get<0>(it);
|
||||
for (auto it : llvm::zip(op.getOutputOperands(), neutralElements)) {
|
||||
Value rankedTensor = std::get<0>(it)->get();
|
||||
auto t = rankedTensor.getType().cast<RankedTensorType>();
|
||||
RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
|
||||
reductionDimSize / splitFactor, insertSplitDimension);
|
||||
@@ -366,7 +366,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
|
||||
// Step 2. Reindex / expand indexing maps.
|
||||
// Reindex existing input indexings: k -> k * splitFactor + k'.
|
||||
SmallVector<AffineMap> newMaps;
|
||||
newMaps.reserve(op.getNumInputsAndOutputs() + 1);
|
||||
newMaps.reserve(op->getNumOperands() + 1);
|
||||
for (OpOperand *o : op.getInputOperands())
|
||||
newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
|
||||
// Provision a new indexing for the shape-only tensor.
|
||||
@@ -384,7 +384,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
|
||||
|
||||
// Step 3. Handle operands.
|
||||
// Compute the new input tensors.
|
||||
auto newInputs = llvm::to_vector<4>(op.getInputs());
|
||||
SmallVector<Value> newInputs(op.getInputOperands());
|
||||
// Add a single shape-only tensor to carry the dimensions without resorting to
|
||||
// more complex inversions.
|
||||
newInputs.push_back(b.create<tensor::EmptyOp>(
|
||||
@@ -413,10 +413,10 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
|
||||
// TODO: all results can be handled in a single GenericOp, when
|
||||
// multi-reduction support is available.
|
||||
SmallVector<LinalgOp> results;
|
||||
for (auto it :
|
||||
llvm::zip(genericOp->getResults(), op.getOutputs(), combinerOps)) {
|
||||
for (auto it : llvm::zip(genericOp->getResults(), op.getOutputOperands(),
|
||||
combinerOps)) {
|
||||
Value reindexedOutput = std::get<0>(it);
|
||||
Value originalOutput = std::get<1>(it);
|
||||
Value originalOutput = std::get<1>(it)->get();
|
||||
auto originalOutputType = originalOutput.getType().cast<RankedTensorType>();
|
||||
Operation *combinerOp = std::get<2>(it);
|
||||
|
||||
|
||||
@@ -503,7 +503,7 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
|
||||
// Tile the `operandValuesToUse` that either match the `op` operands
|
||||
// themselves or the tile loop arguments forwarding them.
|
||||
assert(operandValuesToUse.size() ==
|
||||
static_cast<size_t>(op.getNumInputsAndOutputs()) &&
|
||||
static_cast<size_t>(op->getNumOperands()) &&
|
||||
"expect the number of operands and inputs and outputs to match");
|
||||
SmallVector<Value> valuesToTile = operandValuesToUse;
|
||||
SmallVector<OpFoldResult> sizeBounds =
|
||||
|
||||
@@ -125,14 +125,12 @@ struct LinalgOpTilingInterface
|
||||
// specified could lead to out of bounds accesses.
|
||||
Location loc = op->getLoc();
|
||||
LinalgOp linalgOp = cast<LinalgOp>(op);
|
||||
SmallVector<Value> valuesToTile = linalgOp.getInputAndOutputOperands();
|
||||
SmallVector<Value> valuesToTile = linalgOp->getOperands();
|
||||
SmallVector<Value, 4> tiledOperands = makeTiledShapes(
|
||||
b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
|
||||
|
||||
SmallVector<Type> resultTensorTypes = llvm::to_vector(llvm::map_range(
|
||||
linalgOp.getOutputTensorOperands(), [&](OpOperand *opOperand) {
|
||||
return tiledOperands[opOperand->getOperandNumber()].getType();
|
||||
}));
|
||||
SmallVector<Type> resultTensorTypes =
|
||||
getTensorOutputTypes(linalgOp, tiledOperands);
|
||||
|
||||
Operation *tiledOp =
|
||||
linalgOp.clone(b, loc, resultTensorTypes, tiledOperands);
|
||||
@@ -222,23 +220,23 @@ struct LinalgOpTilingInterface
|
||||
return op->emitOpError("expected operation to have buffer semantics");
|
||||
|
||||
SmallVector<Value> indexedValues;
|
||||
indexedValues.reserve(linalgOp.getNumInputsAndOutputs());
|
||||
indexedValues.reserve(linalgOp->getNumOperands());
|
||||
Location linalgOpLoc = op->getLoc();
|
||||
/// Load the data corresponding to the block arguments that
|
||||
/// represent input operands.
|
||||
for (OpOperand *operand : linalgOp.getInputAndOutputOperands()) {
|
||||
if (!linalgOp.payloadUsesValueFromOperand(operand)) {
|
||||
for (OpOperand &operand : linalgOp->getOpOperands()) {
|
||||
if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
|
||||
indexedValues.push_back(nullptr);
|
||||
continue;
|
||||
}
|
||||
if (linalgOp.isScalar(operand)) {
|
||||
indexedValues.push_back(operand->get());
|
||||
if (linalgOp.isScalar(&operand)) {
|
||||
indexedValues.push_back(operand.get());
|
||||
continue;
|
||||
}
|
||||
SmallVector<Value> indices = getIndicesForAccess(
|
||||
builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(operand), ivs);
|
||||
builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
|
||||
Value load =
|
||||
builder.create<memref::LoadOp>(linalgOpLoc, operand->get(), indices);
|
||||
builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
|
||||
indexedValues.push_back(load);
|
||||
}
|
||||
|
||||
|
||||
@@ -203,10 +203,10 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
|
||||
b.setInsertionPointAfter(opToPad);
|
||||
// Make a copy of the shaped operands and update it.
|
||||
SmallVector<Value> newOperands;
|
||||
newOperands.reserve(opToPad.getNumInputsAndOutputs());
|
||||
for (OpOperand *opOperand : opToPad.getInputAndOutputOperands()) {
|
||||
newOperands.reserve(opToPad->getNumOperands());
|
||||
for (OpOperand &opOperand : opToPad->getOpOperands()) {
|
||||
FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
|
||||
b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
|
||||
b, opToPad, &opOperand, paddingDimensions, paddingValues, packPaddings);
|
||||
// Exit if `paddingDimensions` cannot be bounded statically.
|
||||
if (failed(paddedOperand))
|
||||
return failure();
|
||||
@@ -327,15 +327,15 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
|
||||
|
||||
// Hoist the padding.
|
||||
for (const auto &en : enumerate(options.hoistPaddings)) {
|
||||
if (static_cast<int64_t>(en.index()) >= paddedOp.getNumInputsAndOutputs())
|
||||
if (static_cast<int64_t>(en.index()) >= paddedOp->getNumOperands())
|
||||
break;
|
||||
OpOperand *opOperand = &paddedOp->getOpOperand(en.index());
|
||||
auto padOp = opOperand->get().getDefiningOp<tensor::PadOp>();
|
||||
OpOperand &opOperand = paddedOp->getOpOperand(en.index());
|
||||
auto padOp = opOperand.get().getDefiningOp<tensor::PadOp>();
|
||||
if (!padOp || en.value() == 0)
|
||||
continue;
|
||||
|
||||
// Fail hoisting if the operand shape is not fully static.
|
||||
if (llvm::any_of(paddedOp.getShape(opOperand), ShapedType::isDynamic))
|
||||
if (llvm::any_of(paddedOp.getShape(&opOperand), ShapedType::isDynamic))
|
||||
return failure();
|
||||
|
||||
tensor::PadOp hoistedOp;
|
||||
|
||||
@@ -459,35 +459,35 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
|
||||
// 3. Turn all BBArgs into vector.transfer_read / load.
|
||||
Location loc = linalgOp.getLoc();
|
||||
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
||||
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
|
||||
BlockArgument bbarg = block->getArgument(opOperand->getOperandNumber());
|
||||
if (linalgOp.isScalar(opOperand)) {
|
||||
bvm.map(bbarg, opOperand->get());
|
||||
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
|
||||
BlockArgument bbarg = block->getArgument(opOperand.getOperandNumber());
|
||||
if (linalgOp.isScalar(&opOperand)) {
|
||||
bvm.map(bbarg, opOperand.get());
|
||||
continue;
|
||||
}
|
||||
VectorType readType;
|
||||
AffineMap map;
|
||||
// TODO: can we keep this simplification?
|
||||
// if (linalgOp.getShape(opOperand).empty()) {
|
||||
// if (linalgOp.getShape(&opOperand).empty()) {
|
||||
// readType = VectorType::get({}, bbarg.getType());
|
||||
// } else {
|
||||
if (opOperand->getOperandNumber() < linalgOp.getNumInputs()) {
|
||||
if (opOperand.getOperandNumber() < linalgOp.getNumInputs()) {
|
||||
map = inverseAndBroadcastProjectedPermutation(
|
||||
linalgOp.getMatchingIndexingMap(opOperand));
|
||||
linalgOp.getMatchingIndexingMap(&opOperand));
|
||||
readType = VectorType::get(commonVectorShape,
|
||||
getElementTypeOrSelf(opOperand->get()));
|
||||
getElementTypeOrSelf(opOperand.get()));
|
||||
} else {
|
||||
map = inversePermutation(
|
||||
reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand)));
|
||||
readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
|
||||
getElementTypeOrSelf(opOperand->get()));
|
||||
reindexIndexingMap(linalgOp.getMatchingIndexingMap(&opOperand)));
|
||||
readType = VectorType::get(map.compose(linalgOp.getShape(&opOperand)),
|
||||
getElementTypeOrSelf(opOperand.get()));
|
||||
}
|
||||
// }
|
||||
|
||||
auto shape = linalgOp.getShape(opOperand);
|
||||
auto shape = linalgOp.getShape(&opOperand);
|
||||
SmallVector<Value> indices(shape.size(), zero);
|
||||
Value readValue = b.create<vector::TransferReadOp>(
|
||||
loc, readType, opOperand->get(), indices, map);
|
||||
loc, readType, opOperand.get(), indices, map);
|
||||
// Not all ops support 0-d vectors, extract the scalar for now.
|
||||
// TODO: remove this.
|
||||
if (readValue.getType().cast<VectorType>().getRank() == 0)
|
||||
@@ -495,7 +495,7 @@ vectorizeAsLinalgGeneric(OpBuilder &b, LinalgOp linalgOp,
|
||||
|
||||
LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
|
||||
bvm.map(bbarg, readValue);
|
||||
bvm.map(opOperand->get(), readValue);
|
||||
bvm.map(opOperand.get(), readValue);
|
||||
}
|
||||
|
||||
SmallVector<CustomVectorizationHook> hooks;
|
||||
@@ -1342,9 +1342,9 @@ struct Conv1DGenerator : public StructuredGenerator<LinalgOp> {
|
||||
// Determine whether `linalgOp` can be generated with this generator
|
||||
if (linalgOp.getNumInputs() != 2 || linalgOp.getNumOutputs() != 1)
|
||||
return;
|
||||
lhsShaped = linalgOp.getInputs()[0];
|
||||
rhsShaped = linalgOp.getInputs()[1];
|
||||
resShaped = linalgOp.getOutputs()[0];
|
||||
lhsShaped = linalgOp.getInputOperand(0)->get();
|
||||
rhsShaped = linalgOp.getInputOperand(1)->get();
|
||||
resShaped = linalgOp.getOutputOperand(0)->get();
|
||||
lhsShapedType = lhsShaped.getType().dyn_cast<ShapedType>();
|
||||
rhsShapedType = rhsShaped.getType().dyn_cast<ShapedType>();
|
||||
resShapedType = resShaped.getType().dyn_cast<ShapedType>();
|
||||
|
||||
@@ -490,17 +490,18 @@ void GenerateLoopNest<scf::ForOp>::doit(
assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
"expected as many entries for proc info as number of loops, even if "
"they are null entries");
SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
? SmallVector<Value>{}
: linalgOp.getOutputOperands();

SmallVector<Value, 4> lbs, ubs, steps;
unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
LoopNest loopNest = mlir::scf::buildLoopNest(
b, loc, lbs, ubs, steps, iterArgInitValues,
[&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() &&
assert(iterArgs.size() == iterArgInitValues.size() &&
"expect the number of output tensors and iter args to match");
SmallVector<Value> operandValuesToUse =
linalgOp.getInputAndOutputOperands();
SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
if (!iterArgs.empty()) {
operandValuesToUse = linalgOp.getInputOperands();
operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
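getOutputTensorOperands() used to return an empty vector for ops on buffers; the loop generators now spell that case out when seeding the loop iter_args. A minimal sketch of the replacement idiom (illustrative):

// Tensor outputs become loop-carried values; buffer outputs carry nothing.
SmallVector<Value> iterArgInitValues =
    linalgOp.hasBufferSemantics() ? SmallVector<Value>{}
                                  : linalgOp.getOutputOperands();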
@@ -530,7 +531,9 @@ void GenerateLoopNest<AffineForOp>::doit(
|
||||
ValueRange)>
|
||||
bodyBuilderFn,
|
||||
ArrayRef<linalg::ProcInfo> /*procInfo*/) {
|
||||
SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
|
||||
SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
|
||||
? SmallVector<Value>{}
|
||||
: linalgOp.getOutputOperands();
|
||||
assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
|
||||
SmallVector<Value, 4> lbs, ubs, steps;
|
||||
unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
|
||||
@@ -546,9 +549,8 @@ void GenerateLoopNest<AffineForOp>::doit(
|
||||
|
||||
mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
|
||||
[&](OpBuilder &b, Location loc, ValueRange ivs) {
|
||||
SmallVector<Value> operandValuesToUse =
|
||||
linalgOp.getInputAndOutputOperands();
|
||||
bodyBuilderFn(b, loc, ivs, operandValuesToUse);
|
||||
bodyBuilderFn(b, loc, ivs,
|
||||
linalgOp->getOperands());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -695,7 +697,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
|
||||
ValueRange)>
|
||||
bodyBuilderFn,
|
||||
ArrayRef<linalg::ProcInfo> procInfo) {
|
||||
SmallVector<Value> iterArgInitValues = linalgOp.getOutputTensorOperands();
|
||||
SmallVector<Value> iterArgInitValues = linalgOp.hasBufferSemantics()
|
||||
? SmallVector<Value>{}
|
||||
: linalgOp.getOutputOperands();
|
||||
assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
|
||||
// This function may be passed more iterator types than ranges.
|
||||
assert(iteratorTypes.size() >= loopRanges.size() &&
|
||||
@@ -725,9 +729,7 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
|
||||
generateParallelLoopNest(
|
||||
b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
|
||||
[&](OpBuilder &b, Location loc, ValueRange ivs) {
|
||||
SmallVector<Value> operandValuesToUse =
|
||||
linalgOp.getInputAndOutputOperands();
|
||||
bodyBuilderFn(b, loc, ivs, operandValuesToUse);
|
||||
bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
|
||||
},
|
||||
ivs);
|
||||
|
||||
@@ -905,10 +907,10 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
|
||||
}
|
||||
|
||||
SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
|
||||
// TODO: use an interface/adaptor to avoid leaking position in
|
||||
// `tiledOperands`.
|
||||
if (op.hasBufferSemantics())
|
||||
return {};
|
||||
return llvm::to_vector(
|
||||
llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) {
|
||||
llvm::map_range(op.getOutputOperands(), [&](OpOperand *opOperand) {
|
||||
return operands[opOperand->getOperandNumber()].getType();
|
||||
}));
|
||||
}
|
||||
@@ -916,11 +918,13 @@ SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
|
||||
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
|
||||
LinalgOp op, ValueRange operands,
|
||||
ValueRange results) {
|
||||
if (op.hasBufferSemantics())
|
||||
return {};
|
||||
SmallVector<Value> tensorResults;
|
||||
tensorResults.reserve(results.size());
|
||||
// Insert a insert_slice for each output tensor.
|
||||
unsigned resultIdx = 0;
|
||||
for (OpOperand *opOperand : op.getOutputTensorOperands()) {
|
||||
for (OpOperand *opOperand : op.getOutputOperands()) {
|
||||
// TODO: use an interface/adaptor to avoid leaking position in
|
||||
// `tiledOperands`.
|
||||
Value outputTensor = operands[opOperand->getOperandNumber()];
|
||||
@@ -958,23 +962,26 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
|
||||
computeTileSizes(builder, loc, tileSizes, sizeBounds);
|
||||
|
||||
assert(static_cast<int64_t>(valuesToTile.size()) ==
|
||||
linalgOp.getNumInputsAndOutputs() &&
|
||||
linalgOp->getNumOperands() &&
|
||||
"expected one value to tile for every operand");
|
||||
SmallVector<Optional<SliceParameters>> allSliceParams;
|
||||
allSliceParams.reserve(valuesToTile.size());
|
||||
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
|
||||
Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
|
||||
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
|
||||
Value shapedOp = valuesToTile[opOperand.getOperandNumber()];
|
||||
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
|
||||
AffineMap map = linalgOp.getMatchingIndexingMap(opOperand);
|
||||
AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
|
||||
// Use `opOperand` as is if it is not tiled and not an output tensor. Having
|
||||
// an extract/insert slice pair for all output tensors simplifies follow up
|
||||
// transformations such as padding and bufferization since the
|
||||
// extract/insert slice pairs make the accessed iteration argument
|
||||
// subdomains explicit.
|
||||
if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
|
||||
|
||||
Type operandType = opOperand.get().getType();
|
||||
if (!isTiled(map, tileSizes) && !(operandType.isa<RankedTensorType>() &&
|
||||
linalgOp.isOutput(&opOperand))) {
|
||||
allSliceParams.push_back(llvm::None);
|
||||
LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
|
||||
<< opOperand->get().getType() << "\n");
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
<< ": not tiled: use shape: " << operandType << "\n");
|
||||
continue;
|
||||
}
|
||||
LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
|
||||
|
||||
@@ -105,8 +105,7 @@ static bool isZeroYield(GenericOp op) {
|
||||
auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
|
||||
if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
|
||||
if (arg.getOwner()->getParentOp() == op) {
|
||||
OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
|
||||
return isZeroValue(t->get());
|
||||
return isZeroValue(op->getOperand(arg.getArgNumber()));
|
||||
}
|
||||
}
|
||||
return isZeroValue(yieldOp.getOperand(0));
|
||||
@@ -242,8 +241,8 @@ public:
|
||||
return failure();
|
||||
// Modify operand structure of producer and consumer.
|
||||
Location loc = prod.getLoc();
|
||||
SmallVector<Value> inputOps = prod.getInputOperands();
|
||||
SmallVector<Value> outputOps = op.getOutputOperands();
|
||||
SmallVector<Value> inputOps = prod.getInputs();
|
||||
SmallVector<Value> outputOps = op.getOutputs();
|
||||
SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMapsArray();
|
||||
inputOps.push_back(op.getInputOperand(1 - other)->get());
|
||||
fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
|
||||
|
||||
@@ -194,14 +194,14 @@ static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
|
||||
/// no annotations are found or inadmissible constructs occur.
|
||||
static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
|
||||
bool annotated = false;
|
||||
for (OpOperand *t : op.getInputAndOutputOperands()) {
|
||||
auto map = op.getMatchingIndexingMap(t);
|
||||
auto enc = getSparseTensorEncoding(t->get().getType());
|
||||
for (OpOperand &t : op->getOpOperands()) {
|
||||
auto map = op.getMatchingIndexingMap(&t);
|
||||
auto enc = getSparseTensorEncoding(t.get().getType());
|
||||
if (enc)
|
||||
annotated = true;
|
||||
assert(map.getNumResults() == op.getRank(t));
|
||||
assert(map.getNumResults() == op.getRank(&t));
|
||||
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
|
||||
unsigned tensor = t->getOperandNumber();
|
||||
unsigned tensor = t.getOperandNumber();
|
||||
AffineExpr a = map.getResult(toOrigDim(enc, d));
|
||||
if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
|
||||
return false; // inadmissible affine expression
|
||||
@@ -291,13 +291,13 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
|
||||
std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
|
||||
auto iteratorTypes = op.getIteratorTypesArray();
|
||||
// Iterate over the indexing maps of every tensor in the tensor expression.
|
||||
for (OpOperand *t : op.getInputAndOutputOperands()) {
|
||||
for (OpOperand &t : op->getOpOperands()) {
|
||||
// Skip tensor during cycle resolution.
|
||||
if (t == skip)
|
||||
if (&t == skip)
|
||||
continue;
|
||||
// Get map and encoding.
|
||||
auto map = op.getMatchingIndexingMap(t);
|
||||
auto enc = getSparseTensorEncoding(t->get().getType());
|
||||
auto map = op.getMatchingIndexingMap(&t);
|
||||
auto enc = getSparseTensorEncoding(t.get().getType());
|
||||
assert(map.getNumDims() == n);
|
||||
// Skip dense tensor constraints when not requested.
|
||||
if (!(mask & SortMask::kIncludeDense) && !enc)
|
||||
@@ -314,7 +314,7 @@ static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
|
||||
// Push unrelated loops into sparse iteration space, so these
|
||||
// will be skipped more often.
|
||||
if (mask & SortMask::kIncludeUndef) {
|
||||
unsigned tensor = t->getOperandNumber();
|
||||
unsigned tensor = t.getOperandNumber();
|
||||
for (unsigned i = 0; i < n; i++)
|
||||
if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
|
||||
merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
|
||||
@@ -534,16 +534,16 @@ static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
|
||||
static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
|
||||
linalg::GenericOp op) {
|
||||
Location loc = op.getLoc();
|
||||
assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
|
||||
assert(op->getNumOperands() == op.getNumInputs() + 1);
|
||||
// For every tensor, find lower and upper bound on dimensions, set the
|
||||
// same bounds on loop indices, and obtain dense or sparse buffer(s).
|
||||
auto dynShape = {ShapedType::kDynamicSize};
|
||||
SmallVector<Value, 4> args;
|
||||
for (OpOperand *t : op.getInputAndOutputOperands()) {
|
||||
unsigned tensor = t->getOperandNumber();
|
||||
auto shape = op.getShape(t);
|
||||
auto map = op.getMatchingIndexingMap(t);
|
||||
auto enc = getSparseTensorEncoding(t->get().getType());
|
||||
for (OpOperand &t : op->getOpOperands()) {
|
||||
unsigned tensor = t.getOperandNumber();
|
||||
auto shape = op.getShape(&t);
|
||||
auto map = op.getMatchingIndexingMap(&t);
|
||||
auto enc = getSparseTensorEncoding(t.get().getType());
|
||||
// Scan all dimensions of current tensor.
|
||||
args.clear();
|
||||
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
|
||||
@@ -560,23 +560,23 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
|
||||
MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
|
||||
auto dim = builder.getIndexAttr(d);
|
||||
codegen.pointers[tensor][idx] =
|
||||
builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
|
||||
builder.create<ToPointersOp>(loc, ptrTp, t.get(), dim);
|
||||
codegen.indices[tensor][idx] =
|
||||
builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
|
||||
builder.create<ToIndicesOp>(loc, indTp, t.get(), dim);
|
||||
} else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
|
||||
// Singleton dimension, fetch indices.
|
||||
auto indTp =
|
||||
MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
|
||||
auto dim = builder.getIndexAttr(d);
|
||||
codegen.indices[tensor][idx] =
|
||||
builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
|
||||
builder.create<ToIndicesOp>(loc, indTp, t.get(), dim);
|
||||
} else {
|
||||
// Dense dimension, nothing to fetch.
|
||||
assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense));
|
||||
}
|
||||
// Find upper bound in current dimension.
|
||||
unsigned p = toOrigDim(enc, d);
|
||||
Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p);
|
||||
Value up = linalg::createOrFoldDimOp(builder, loc, t.get(), p);
|
||||
if (ShapedType::isDynamic(shape[p]))
|
||||
args.push_back(up);
|
||||
assert(codegen.highs[tensor][idx] == nullptr);
|
||||
@@ -585,21 +585,21 @@ static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
|
||||
// Perform the required bufferization. Dense inputs materialize
|
||||
// from the input tensors. Dense outputs need special handling.
|
||||
// Sparse inputs use sparse primitives to obtain the values.
|
||||
Type elementType = getElementTypeOrSelf(t->get().getType());
|
||||
Type elementType = getElementTypeOrSelf(t.get().getType());
|
||||
if (!enc) {
|
||||
// Non-annotated dense tensors.
|
||||
auto denseTp = MemRefType::get(shape, elementType);
|
||||
if (tensor < op.getNumInputs())
|
||||
codegen.buffers[tensor] =
|
||||
builder.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
|
||||
builder.create<bufferization::ToMemrefOp>(loc, denseTp, t.get());
|
||||
else
|
||||
codegen.buffers[tensor] =
|
||||
genOutputBuffer(codegen, builder, op, denseTp, args);
|
||||
} else if (t != codegen.sparseOut) {
|
||||
} else if (&t != codegen.sparseOut) {
|
||||
// Annotated sparse tensors (not involved in output).
|
||||
auto sparseTp = MemRefType::get(dynShape, elementType);
|
||||
codegen.buffers[tensor] =
|
||||
builder.create<ToValuesOp>(loc, sparseTp, t->get());
|
||||
builder.create<ToValuesOp>(loc, sparseTp, t.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -845,15 +845,15 @@ static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
|
||||
return val;
|
||||
}
|
||||
// Load during insertion.
|
||||
OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
|
||||
if (t == codegen.sparseOut) {
|
||||
OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
|
||||
if (&t == codegen.sparseOut) {
|
||||
if (codegen.redCustom != -1u)
|
||||
return genInsertionLoadReduce(merger, codegen, builder, op, t);
|
||||
return genInsertionLoad(codegen, builder, op, t);
|
||||
return genInsertionLoadReduce(merger, codegen, builder, op, &t);
|
||||
return genInsertionLoad(codegen, builder, op, &t);
|
||||
}
|
||||
// Actual load.
|
||||
SmallVector<Value, 4> args;
|
||||
Value ptr = genSubscript(codegen, builder, op, t, args);
|
||||
Value ptr = genSubscript(codegen, builder, op, &t, args);
|
||||
if (codegen.curVecLength > 1)
|
||||
return genVectorLoad(codegen, builder, ptr, args);
|
||||
return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
|
||||
@@ -1093,9 +1093,9 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
|
||||
if (merger.exp(exp).kind == Kind::kTensor) {
|
||||
// Inspect tensor indices.
|
||||
bool atLevel = ldx == -1u;
|
||||
OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
|
||||
auto map = op.getMatchingIndexingMap(t);
|
||||
auto enc = getSparseTensorEncoding(t->get().getType());
|
||||
OpOperand &t = op->getOpOperand(merger.exp(exp).tensor);
|
||||
auto map = op.getMatchingIndexingMap(&t);
|
||||
auto enc = getSparseTensorEncoding(t.get().getType());
|
||||
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
|
||||
AffineExpr a = map.getResult(toOrigDim(enc, d));
|
||||
if (!isInvariantAffine(codegen, a, ldx, atLevel))
|
||||
@@ -1105,7 +1105,7 @@ static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder,
|
||||
if (!atLevel)
|
||||
return;
|
||||
OpOperand *lhs = op.getOutputOperand(0);
|
||||
if (lhs == t) {
|
||||
if (lhs == &t) {
|
||||
// Start or end a scalarized reduction
|
||||
if (atStart) {
|
||||
Kind kind = merger.exp(last).kind;
|
||||
@@ -1288,9 +1288,9 @@ static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction,
|
||||
/// This prevents effective vectorization.
|
||||
static bool denseUnitStrides(Merger &merger, linalg::GenericOp op,
|
||||
unsigned idx) {
|
||||
for (OpOperand *t : op.getInputAndOutputOperands()) {
|
||||
if (!getSparseTensorEncoding(t->get().getType())) {
|
||||
auto map = op.getMatchingIndexingMap(t);
|
||||
for (OpOperand &t : op->getOpOperands()) {
|
||||
if (!getSparseTensorEncoding(t.get().getType())) {
|
||||
auto map = op.getMatchingIndexingMap(&t);
|
||||
for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
|
||||
AffineExpr a = map.getResult(d);
|
||||
// Report non-unit stride if innermost index appears at an outer
|
||||
@@ -1856,7 +1856,7 @@ public:
|
||||
// information for all tensors to loop indices in the kernel.
|
||||
if (op.getNumOutputs() != 1)
|
||||
return failure();
|
||||
unsigned numTensors = op.getNumInputsAndOutputs();
|
||||
unsigned numTensors = op->getNumOperands();
|
||||
unsigned numLoops = op.getNumLoops();
|
||||
Merger merger(numTensors, numLoops);
|
||||
if (!findSparseAnnotations(merger, op))
|
||||
|
||||
@@ -910,10 +910,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
// argument is considered a tensor, indexed by the implicit loop
// bounds. This includes rank-0 tensor arguments.
if (arg.getOwner()->getParentOp() == op) {
OpOperand *t = op.getInputAndOutputOperands()[argN];
if (!op.isScalar(t))
OpOperand &t = op->getOpOperand(argN);
if (!op.isScalar(&t))
return addExp(kTensor, argN);
v = t->get(); // get scalar value
v = t.get(); // get scalar value
}
// Any other argument (marked as scalar argument for the generic op
// or belonging to an enveloping op) is considered invariant.

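Indexed access is simplified the same way: rather than materializing the whole operand vector just to pick one element, callers index the operation directly. A short sketch, with argN the block-argument number being looked up (illustrative):

// Old: op.getInputAndOutputOperands()[argN] built a vector first.
// New: ask the Operation for the OpOperand at that position.
OpOperand &t = op->getOpOperand(argN);
Value v = t.get();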
@@ -275,7 +275,7 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @remove_deadargs_generic_basic
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
|
||||
// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
|
||||
|
||||
@@ -121,26 +121,6 @@ func.func @generic(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offset: ?>>
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>)
// CHECK-SAME: {foo = 1 : i64}

func.func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
%arg1: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
%cst = arith.constant 0.0 : f32
linalg.generic #trait_0
ins(%arg0, %cst : tensor<?x?xvector<3x4xi4>>, f32)
outs(%arg1 : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>)
attrs = {foo = 1} {
^bb(%0: vector<3x4xi4>, %1: f32, %2: f32) :
linalg.yield %1 : f32
}
return
}
// CHECK-LABEL: func @generic_with_tensor_input
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
// CHECK-SAME: ins({{.*}}, {{.*}} : tensor<?x?xvector<3x4xi4>>, f32)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>)
// CHECK-SAME: {foo = 1 : i64}

// -----

#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@@ -300,27 +280,19 @@ func.func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, strided<[?, 1], offs

func.func @named_ops(%a3: memref<?x?x?xf32>, %b3: memref<?x?x?xf32>, %c3: memref<?x?x?xf32>,
%ta3: tensor<?x?x?xf32>, %tb3: tensor<?x?x?xf32>, %tc3: tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>, tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>)
{
linalg.batch_matmul ins(%a3, %b3: memref<?x?x?xf32>, memref<?x?x?xf32>)
outs(%c3: memref<?x?x?xf32>)
linalg.batch_matmul ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
outs(%c3: memref<?x?x?xf32>)
%res1 = linalg.batch_matmul
ins(%ta3, %tb3: tensor<?x?x?xf32>, tensor<?x?x?xf32>)
outs(%tc3: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32>
%res2 = linalg.batch_matmul
ins(%ta3, %b3: tensor<?x?x?xf32>, memref<?x?x?xf32>)
outs(%tc3: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32>
return %res1, %res2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
return %res1 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @named_ops
// CHECK: linalg.batch_matmul
// CHECK: linalg.batch_matmul
// CHECK: linalg.batch_matmul
// CHECK: linalg.batch_matmul

// -----


@@ -26,7 +26,7 @@ static void addOperands(Operation *op, SetVector<Value> &operandSet) {
return;
TypeSwitch<Operation *, void>(op)
.Case<linalg::LinalgOp>([&](linalg::LinalgOp linalgOp) {
SmallVector<Value> inputOperands = linalgOp.getInputOperands();
SmallVector<Value> inputOperands{linalgOp.getInputOperands()};
operandSet.insert(inputOperands.begin(), inputOperands.end());
})
.Default([&](Operation *operation) {
@@ -147,7 +147,7 @@ struct TestLinalgElementwiseFusion
if (expandOp->hasOneUse()) {
OpOperand &use = *expandOp->getUses().begin();
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
if (linalgOp && linalgOp.isOutputTensor(&use))
if (linalgOp && linalgOp.isOutput(&use))
return true;
}
return false;

@@ -38,14 +38,14 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
// Tile and Fuse for tensors inputs (TODO: all tensor operands).
bool changed = false;
for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
if (opOperand->get().getType().isa<MemRefType>()) {
for (OpOperand &opOperand : linalgOp->getOpOperands()) {
if (opOperand.get().getType().isa<MemRefType>()) {
// TODO: LinalgDependenceGraph should be able to update itself.
// The current naive and expensive reconstruction of the graph should be
// removed.
linalg::Aliases aliases;
linalg::LinalgDependenceGraph graph(aliases, linalgOps);
auto info = fuseProducerOfBuffer(b, *opOperand, graph);
auto info = fuseProducerOfBuffer(b, opOperand, graph);
if (failed(info))
continue;
auto *originalOp = info->originalProducer.getOperation();
@@ -54,11 +54,11 @@ static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
std::find(linalgOps.begin(), linalgOps.end(), originalOp);
*originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
changed = true;
} else if (opOperand->get().getType().isa<RankedTensorType>()) {
} else if (opOperand.get().getType().isa<RankedTensorType>()) {
// Tile and Fuse tensor input.
if (opOperand->getOperandNumber() >= linalgOp.getNumInputs())
if (opOperand.getOperandNumber() >= linalgOp.getNumInputs())
continue;
auto info = fuseProducerOfTensor(b, *opOperand);
auto info = fuseProducerOfTensor(b, opOperand);
if (failed(info))
continue;
auto *originalOp = info->originalProducer.getOperation();

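As context for the fuseLinalgOpsGreedily hunk above: once only plain OpOperand references are iterated, whether an operand is an input is decided purely by its position, since outputs trail the inputs in the operand list. A small sketch under that assumption; the helper name isInputOperand is hypothetical:

  // Sketch: inputs occupy the leading operand positions, outputs the trailing
  // ones, so a position check against getNumInputs() suffices.
  static bool isInputOperand(linalg::LinalgOp linalgOp, OpOperand &opOperand) {
    return opOperand.getOperandNumber() <
           static_cast<unsigned>(linalgOp.getNumInputs());
  }
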
@@ -2835,9 +2835,10 @@ def TestLinalgConvOp :
return "";
}

// To conform with interface requirement on operand naming.
mlir::ValueRange inputs() { return getInputs(); }
mlir::ValueRange outputs() { return getOutputs(); }
std::pair<int64_t, int64_t> getOutputsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
}];
}

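The getOutputsPositionRange() hook added to the test ops above and below is the core of the simplified interface: it returns a pair of operand positions delimiting the outputs, here the single trailing operand. A hedged sketch of how derived queries can be expressed on top of it, assuming the pair is a half-open [start, end) range; these free functions are illustrative, not the interface's actual method bodies:

  // Sketch: derive output count and lookup from the position range.
  static int64_t numOutputs(DestinationStyleOpInterface op) {
    auto [start, end] = op.getOutputsPositionRange();
    return end - start;
  }
  static OpOperand *outputOperand(DestinationStyleOpInterface op, int64_t i) {
    auto [start, end] = op.getOutputsPositionRange();
    assert(i >= 0 && start + i < end && "output index out of range");
    return &op->getOpOperand(start + i);
  }
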
@@ -2894,9 +2895,10 @@ def TestLinalgFillOp :
return "";
}

// To conform with interface requirement on operand naming.
mlir::ValueRange inputs() { return getInputs(); }
mlir::ValueRange outputs() { return getOutputs(); }
std::pair<int64_t, int64_t> getOutputsPositionRange() {
int64_t getNumOperands = this->getNumOperands();
return {getNumOperands - 1, getNumOperands};
}
}];
}


@@ -563,6 +563,11 @@ def {0} : LinalgStructuredBase_Op<"{1}", !listconcat([AttrSizedOperandSegments],
return regionBuilder;
}

std::pair<int64_t, int64_t> getOutputsPositionRange() {{
int64_t getNumOperands = this->getNumOperands();
return {{getNumOperands - 1, getNumOperands};
}

// Generic methods.
static unsigned getNumRegionArgs();
std::string getLibraryCallName();
@@ -638,8 +643,8 @@ ArrayAttr {0}::getIndexingMaps() {{
AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(
getNumParallelLoops(), context);
SmallVector<AffineMap> indexingMaps;
for (OpOperand *opOperand : getInputAndOutputOperands())
indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
for (OpOperand &opOperand : getOperation()->getOpOperands())
indexingMaps.push_back(getRank(&opOperand) == 0 ? scalarMap : tensorMap);
return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
}
)FMT";
@@ -654,10 +659,9 @@ LogicalResult {0}::fold(ArrayRef<Attribute>,
}
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
SmallVector<Value> inputBuffers = getInputBufferOperands();
SmallVector<Value> outputBuffers = getOutputBufferOperands();
if (hasTensorSemantics()) return;
getGenericEffectsImpl(effects,
getOperation()->getResults(), inputBuffers, outputBuffers);
getOperation()->getResults(), getInputOperands(), getOutputOperands());
}
)FMT";


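With getInputBufferOperands()/getOutputBufferOperands() dropped from the generated getEffects() above, memory effects are now reported for all operands once tensor semantics is ruled out. A caller that still needs only the memref-typed input values can filter explicitly; a minimal sketch, assuming a valid linalg::LinalgOp and the old .isa<> cast style used elsewhere in this patch:

  // Sketch, not part of the patch: recover the old "input buffers" view.
  static SmallVector<Value> inputBufferValues(linalg::LinalgOp linalgOp) {
    SmallVector<Value> buffers;
    for (OpOperand *opOperand : linalgOp.getInputOperands())
      if (opOperand->get().getType().isa<MemRefType>())
        buffers.push_back(opOperand->get());
    return buffers;
  }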