mirror of
https://github.com/intel/llvm.git
synced 2026-01-26 12:26:52 +08:00
[MLIR] Add bufferization state to getBufferType and resolveConflicts interface methods (#141466)
This PR continues the work started in #141019 by also passing the `BufferizationState` to the `getBufferType` and `resolveConflicts` interface methods, as well as to the supporting functions that are used throughout the bufferization infrastructure.
This commit is contained in:
@@ -598,13 +598,14 @@ private:
|
||||
FailureOr<Value>
|
||||
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
|
||||
const BufferizationOptions &options,
|
||||
bool copy = true);
|
||||
const BufferizationState &state, bool copy = true);
|
||||
|
||||
/// Look up the buffer for the given value. If the value was not bufferized
|
||||
/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
|
||||
/// from which the memref operand is returned.
|
||||
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
|
||||
const BufferizationOptions &options);
|
||||
const BufferizationOptions &options,
|
||||
const BufferizationState &state);
|
||||
|
||||
/// Return the buffer type for a given Value (tensor) after bufferization
|
||||
/// without bufferizing any IR.
|
||||
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
|
||||
///
|
||||
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
|
||||
FailureOr<BaseMemRefType> getBufferType(Value value,
|
||||
const BufferizationOptions &options);
|
||||
const BufferizationOptions &options,
|
||||
const BufferizationState &state);
|
||||
|
||||
/// Return the buffer type for a given Value (tensor) after bufferization
|
||||
/// without bufferizing any IR. This function (and not the other overload
|
||||
@@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
|
||||
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
|
||||
FailureOr<BaseMemRefType> getBufferType(Value value,
|
||||
const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack);
|
||||
|
||||
/// Return "true" if the given op has tensor semantics and should be bufferized.
|
||||
@@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
|
||||
/// places.
|
||||
FailureOr<BaseMemRefType>
|
||||
defaultGetBufferType(Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack);
|
||||
|
||||
/// This is the default implementation of
|
||||
|
||||
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||
/*retType=*/"::llvm::LogicalResult",
|
||||
/*methodName=*/"resolveConflicts",
|
||||
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
|
||||
"const ::mlir::bufferization::AnalysisState &":$state),
|
||||
"const ::mlir::bufferization::AnalysisState &":$analysisState,
|
||||
"const ::mlir::bufferization::BufferizationState &":$bufferizationState),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
auto bufferizableOp =
|
||||
::llvm::cast<BufferizableOpInterface>($_op.getOperation());
|
||||
return bufferizableOp.resolveTensorOpOperandConflicts(
|
||||
rewriter, state);
|
||||
rewriter, analysisState, bufferizationState);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
@@ -528,6 +529,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||
/*methodName=*/"getBufferType",
|
||||
/*args=*/(ins "::mlir::Value":$value,
|
||||
"const ::mlir::bufferization::BufferizationOptions &":$options,
|
||||
"const ::mlir::bufferization::BufferizationState &":$state,
|
||||
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
|
||||
/*methodBody=*/"",
|
||||
/*defaultImplementation=*/[{
|
||||
@@ -536,7 +538,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||
assert(invocationStack.back() == value &&
|
||||
"inconsistant invocation stack");
|
||||
return ::mlir::bufferization::detail::defaultGetBufferType(
|
||||
value, options, invocationStack);
|
||||
value, options, state, invocationStack);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
@@ -621,7 +623,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||
/// form of `bufferization.alloc_tensor` ops.
|
||||
::llvm::LogicalResult resolveTensorOpOperandConflicts(
|
||||
::mlir::RewriterBase &rewriter,
|
||||
const ::mlir::bufferization::AnalysisState &state);
|
||||
const ::mlir::bufferization::AnalysisState &analysisState,
|
||||
const ::mlir::bufferization::BufferizationState &bufferizationState);
|
||||
|
||||
/// Return `true` if the given OpOperand creates an alias but does neither
|
||||
/// read nor write. This implies that `bufferizesToMemoryRead` and
|
||||
|
||||
@@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
|
||||
|
||||
FailureOr<BaseMemRefType> getBufferType(
|
||||
Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack);
|
||||
|
||||
RankedTensorType getType() {
|
||||
@@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
|
||||
|
||||
FailureOr<BaseMemRefType> getBufferType(
|
||||
Value value, const BufferizationOptions &options,
|
||||
SmallVector<Value> &invocationStack) {
|
||||
const BufferizationState &state, SmallVector<Value> &invocationStack) {
|
||||
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
|
||||
}
|
||||
}];
|
||||
|
||||
@@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
// Note: The user may want to override this function for OpResults in
|
||||
// case the bufferized result type is different from the bufferized type of
|
||||
// the aliasing OpOperand (if any).
|
||||
if (isa<OpResult>(value))
|
||||
return bufferization::detail::defaultGetBufferType(value, options,
|
||||
return bufferization::detail::defaultGetBufferType(value, options, state,
|
||||
invocationStack);
|
||||
|
||||
// Compute the buffer type of the block argument by computing the bufferized
|
||||
@@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
|
||||
callerType = memrefType;
|
||||
} else {
|
||||
FailureOr<BaseMemRefType> maybeCallerType =
|
||||
bufferization::getBufferType(opOperand->get(), options,
|
||||
bufferization::getBufferType(opOperand->get(), options, state,
|
||||
invocationStack);
|
||||
if (failed(maybeCallerType))
|
||||
return failure();
|
||||
@@ -81,9 +82,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
|
||||
if (bufferType == callerType)
|
||||
continue;
|
||||
|
||||
// If the computed buffer type does not match the computed buffer type
|
||||
// of the earlier forwarded operands, fall back to a buffer type with a
|
||||
// fully dynamic layout map.
|
||||
// If the computed buffer type does not match the computed buffer type
|
||||
// of the earlier forwarded operands, fall back to a buffer type with a
|
||||
// fully dynamic layout map.
|
||||
#ifndef NDEBUG
|
||||
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
|
||||
assert(bufferType.hasRank() && callerType.hasRank() &&
|
||||
|
||||
@@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
|
||||
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
|
||||
/// computed with `BufferizableOpInterface::getBufferType`.
|
||||
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options);
|
||||
const BufferizationOptions &options,
|
||||
BufferizationState &state);
|
||||
|
||||
} // namespace bufferization
|
||||
} // namespace mlir
|
||||
|
||||
@@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
|
||||
/// additional buffer allocations.
|
||||
LogicalResult insertTensorCopies(Operation *op,
|
||||
const OneShotBufferizationOptions &options,
|
||||
const BufferizationState &bufferizationState,
|
||||
BufferizationStatistics *statistics = nullptr);
|
||||
|
||||
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
|
||||
/// After applying this transform, the IR can be bufferized without inserting
|
||||
/// additional buffer allocations.
|
||||
LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
|
||||
LogicalResult insertTensorCopies(Operation *op,
|
||||
const AnalysisState &analysisState,
|
||||
const BufferizationState &bufferizationState);
|
||||
|
||||
/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
|
||||
/// ops.
|
||||
|
||||
@@ -90,7 +90,8 @@ struct IndexCastOpInterface
|
||||
auto castOp = cast<arith::IndexCastOp>(op);
|
||||
auto resultTensorType = cast<TensorType>(castOp.getType());
|
||||
|
||||
FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
|
||||
FailureOr<Value> source =
|
||||
getBuffer(rewriter, castOp.getIn(), options, state);
|
||||
if (failed(source))
|
||||
return failure();
|
||||
auto sourceType = cast<BaseMemRefType>(source->getType());
|
||||
@@ -151,9 +152,9 @@ struct SelectOpInterface
|
||||
// the moment (one for each tensor). When copying the op result, only one
|
||||
// copy would be needed.
|
||||
FailureOr<Value> maybeTrueBuffer =
|
||||
getBuffer(rewriter, selectOp.getTrueValue(), options);
|
||||
getBuffer(rewriter, selectOp.getTrueValue(), options, state);
|
||||
FailureOr<Value> maybeFalseBuffer =
|
||||
getBuffer(rewriter, selectOp.getFalseValue(), options);
|
||||
getBuffer(rewriter, selectOp.getFalseValue(), options, state);
|
||||
if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
|
||||
return failure();
|
||||
Value trueBuffer = *maybeTrueBuffer;
|
||||
@@ -164,7 +165,7 @@ struct SelectOpInterface
|
||||
// both of them to the most dynamic MemRef type.
|
||||
if (trueBuffer.getType() != falseBuffer.getType()) {
|
||||
auto targetType =
|
||||
bufferization::getBufferType(selectOp.getResult(), options);
|
||||
bufferization::getBufferType(selectOp.getResult(), options, state);
|
||||
if (failed(targetType))
|
||||
return failure();
|
||||
if (trueBuffer.getType() != *targetType)
|
||||
@@ -182,13 +183,14 @@ struct SelectOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto selectOp = cast<arith::SelectOp>(op);
|
||||
assert(value == selectOp.getResult() && "invalid value");
|
||||
auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
|
||||
options, invocationStack);
|
||||
auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
|
||||
options, invocationStack);
|
||||
auto trueType = bufferization::getBufferType(
|
||||
selectOp.getTrueValue(), options, state, invocationStack);
|
||||
auto falseType = bufferization::getBufferType(
|
||||
selectOp.getFalseValue(), options, state, invocationStack);
|
||||
if (failed(trueType) || failed(falseType))
|
||||
return failure();
|
||||
if (*trueType == *falseType)
|
||||
|
||||
@@ -165,7 +165,8 @@ Operation *bufferization::getOwnerOfValue(Value value) {
|
||||
/// allocated.
|
||||
FailureOr<Value> bufferization::allocateTensorForShapedValue(
|
||||
OpBuilder &b, Location loc, Value shapedValue,
|
||||
const BufferizationOptions &options, bool copy) {
|
||||
const BufferizationOptions &options, const BufferizationState &state,
|
||||
bool copy) {
|
||||
Value tensor;
|
||||
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
|
||||
tensor = shapedValue;
|
||||
@@ -210,7 +211,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
|
||||
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
|
||||
if (copy)
|
||||
return allocTensorOp.getResult();
|
||||
FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
|
||||
FailureOr<BaseMemRefType> copyBufferType =
|
||||
getBufferType(tensor, options, state);
|
||||
if (failed(copyBufferType))
|
||||
return failure();
|
||||
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
|
||||
@@ -222,7 +224,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
|
||||
}
|
||||
|
||||
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
|
||||
RewriterBase &rewriter, const AnalysisState &state) {
|
||||
RewriterBase &rewriter, const AnalysisState &analysisState,
|
||||
const BufferizationState &bufferizationState) {
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
Operation *op = getOperation();
|
||||
SmallVector<OpOperand *> outOfPlaceOpOperands;
|
||||
@@ -235,16 +238,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
|
||||
Type operandType = opOperand.get().getType();
|
||||
if (!llvm::isa<TensorType>(operandType))
|
||||
continue;
|
||||
if (state.isInPlace(opOperand))
|
||||
if (analysisState.isInPlace(opOperand))
|
||||
continue;
|
||||
if (llvm::isa<UnrankedTensorType>(operandType))
|
||||
return op->emitError("copying of unranked tensors is not implemented");
|
||||
|
||||
AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
|
||||
AliasingValueList aliasingValues =
|
||||
analysisState.getAliasingValues(opOperand);
|
||||
if (aliasingValues.getNumAliases() == 1 &&
|
||||
isa<OpResult>(aliasingValues.getAliases()[0].value) &&
|
||||
!state.bufferizesToMemoryWrite(opOperand) &&
|
||||
state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
|
||||
!analysisState.bufferizesToMemoryWrite(opOperand) &&
|
||||
analysisState
|
||||
.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
|
||||
.getNumAliases() == 1 &&
|
||||
!isa<UnrankedTensorType>(
|
||||
aliasingValues.getAliases()[0].value.getType())) {
|
||||
@@ -256,12 +261,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
|
||||
// cannot be copied at the moment).
|
||||
Value value = aliasingValues.getAliases()[0].value;
|
||||
outOfPlaceValues.push_back(value);
|
||||
if (!state.canOmitTensorCopy(opOperand))
|
||||
if (!analysisState.canOmitTensorCopy(opOperand))
|
||||
copiedOpValues.insert(value);
|
||||
} else {
|
||||
// In all other cases, make a copy of the OpOperand.
|
||||
outOfPlaceOpOperands.push_back(&opOperand);
|
||||
if (!state.canOmitTensorCopy(opOperand))
|
||||
if (!analysisState.canOmitTensorCopy(opOperand))
|
||||
copiedOpOperands.insert(&opOperand);
|
||||
}
|
||||
}
|
||||
@@ -270,8 +275,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
|
||||
rewriter.setInsertionPoint(op);
|
||||
for (OpOperand *opOperand : outOfPlaceOpOperands) {
|
||||
FailureOr<Value> copy = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
|
||||
copiedOpOperands.contains(opOperand));
|
||||
rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
|
||||
bufferizationState, copiedOpOperands.contains(opOperand));
|
||||
if (failed(copy))
|
||||
return failure();
|
||||
rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
|
||||
@@ -281,8 +286,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
|
||||
rewriter.setInsertionPointAfter(op);
|
||||
for (Value value : outOfPlaceValues) {
|
||||
FailureOr<Value> copy = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), value, state.getOptions(),
|
||||
copiedOpValues.count(value));
|
||||
rewriter, op->getLoc(), value, analysisState.getOptions(),
|
||||
bufferizationState, copiedOpValues.count(value));
|
||||
if (failed(copy))
|
||||
return failure();
|
||||
SmallVector<OpOperand *> uses = llvm::to_vector(
|
||||
@@ -665,7 +670,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
|
||||
}
|
||||
|
||||
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
|
||||
const BufferizationOptions &options) {
|
||||
const BufferizationOptions &options,
|
||||
const BufferizationState &state) {
|
||||
#ifndef NDEBUG
|
||||
auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
|
||||
assert(tensorType && "unexpected non-tensor type");
|
||||
@@ -678,7 +684,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
|
||||
// Insert to_buffer op.
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
setInsertionPointAfter(rewriter, value);
|
||||
FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
|
||||
FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
|
||||
if (failed(memrefType))
|
||||
return failure();
|
||||
ensureToBufferOpIsValid(value, *memrefType);
|
||||
@@ -689,14 +695,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
|
||||
|
||||
/// Return the buffer type for a given Value (tensor) after bufferization.
|
||||
FailureOr<BaseMemRefType>
|
||||
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
|
||||
bufferization::getBufferType(Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state) {
|
||||
SmallVector<Value> invocationStack;
|
||||
return getBufferType(value, options, invocationStack);
|
||||
return getBufferType(value, options, state, invocationStack);
|
||||
}
|
||||
|
||||
/// Return the buffer type for a given Value (tensor) after bufferization.
|
||||
FailureOr<BaseMemRefType>
|
||||
bufferization::getBufferType(Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) {
|
||||
assert(llvm::isa<TensorType>(value.getType()) &&
|
||||
"unexpected non-tensor type");
|
||||
@@ -708,7 +716,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
|
||||
Operation *op = getOwnerOfValue(value);
|
||||
auto bufferizableOp = options.dynCastBufferizableOp(op);
|
||||
if (bufferizableOp)
|
||||
return bufferizableOp.getBufferType(value, options, invocationStack);
|
||||
return bufferizableOp.getBufferType(value, options, state, invocationStack);
|
||||
|
||||
// Op is not bufferizable.
|
||||
auto memSpace =
|
||||
@@ -944,6 +952,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
|
||||
|
||||
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
|
||||
Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &bufferizationState,
|
||||
SmallVector<Value> &invocationStack) {
|
||||
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
|
||||
|
||||
@@ -954,14 +963,15 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
|
||||
// Value is an OpResult.
|
||||
Operation *op = getOwnerOfValue(value);
|
||||
auto opResult = llvm::cast<OpResult>(value);
|
||||
AnalysisState state(options);
|
||||
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
|
||||
AnalysisState analysisState(options);
|
||||
AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(opResult);
|
||||
if (aliases.getNumAliases() > 0 &&
|
||||
aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
|
||||
// If the OpResult has an equivalent OpOperand, both OpResult and
|
||||
// OpOperand bufferize to the exact same buffer type.
|
||||
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
|
||||
return getBufferType(equivalentOperand, options, invocationStack);
|
||||
return getBufferType(equivalentOperand, options, bufferizationState,
|
||||
invocationStack);
|
||||
}
|
||||
|
||||
// If we do not know the memory space and there is no default memory space,
|
||||
|
||||
@@ -163,14 +163,15 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
|
||||
// Get "copy" buffer.
|
||||
Value copyBuffer;
|
||||
if (getCopy()) {
|
||||
FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
|
||||
FailureOr<Value> maybeCopyBuffer =
|
||||
getBuffer(rewriter, getCopy(), options, state);
|
||||
if (failed(maybeCopyBuffer))
|
||||
return failure();
|
||||
copyBuffer = *maybeCopyBuffer;
|
||||
}
|
||||
|
||||
// Create memory allocation.
|
||||
auto allocType = bufferization::getBufferType(getResult(), options);
|
||||
auto allocType = bufferization::getBufferType(getResult(), options, state);
|
||||
if (failed(allocType))
|
||||
return failure();
|
||||
SmallVector<Value> dynamicDims = getDynamicSizes();
|
||||
@@ -223,6 +224,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) {
|
||||
assert(value == getResult() && "invalid value");
|
||||
|
||||
@@ -231,8 +233,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
|
||||
if (getMemorySpace().has_value()) {
|
||||
memorySpace = *getMemorySpace();
|
||||
} else if (getCopy()) {
|
||||
auto copyBufferType =
|
||||
bufferization::getBufferType(getCopy(), options, invocationStack);
|
||||
auto copyBufferType = bufferization::getBufferType(getCopy(), options,
|
||||
state, invocationStack);
|
||||
if (failed(copyBufferType))
|
||||
return failure();
|
||||
memorySpace = copyBufferType->getMemorySpace();
|
||||
@@ -532,7 +534,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
|
||||
const BufferizationOptions &options,
|
||||
BufferizationState &state) {
|
||||
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
|
||||
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
|
||||
@@ -583,7 +585,8 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
|
||||
bool tensorDest = isa<TensorType>(getDest().getType());
|
||||
Value buffer;
|
||||
if (tensorDest) {
|
||||
FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
|
||||
FailureOr<Value> maybeBuffer =
|
||||
getBuffer(rewriter, getDest(), options, state);
|
||||
if (failed(maybeBuffer))
|
||||
return failure();
|
||||
buffer = *maybeBuffer;
|
||||
@@ -591,7 +594,7 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
|
||||
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
|
||||
buffer = getDest();
|
||||
}
|
||||
auto srcBuffer = getBuffer(rewriter, getSource(), options);
|
||||
auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
|
||||
if (failed(srcBuffer))
|
||||
return failure();
|
||||
if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
|
||||
|
||||
@@ -280,8 +280,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
|
||||
BufferizationState &bufferizationState,
|
||||
BufferizationStatistics *statistics) {
|
||||
if (options.copyBeforeWrite) {
|
||||
AnalysisState state(options);
|
||||
if (failed(insertTensorCopies(op, state)))
|
||||
AnalysisState analysisState(options);
|
||||
if (failed(insertTensorCopies(op, analysisState, bufferizationState)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -396,7 +396,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
|
||||
|
||||
LogicalResult
|
||||
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) {
|
||||
const BufferizationOptions &options,
|
||||
BufferizationState &state) {
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
|
||||
if (!bufferizableOp)
|
||||
@@ -412,7 +413,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
|
||||
}
|
||||
|
||||
FailureOr<BaseMemRefType> memrefType =
|
||||
bufferization::getBufferType(bbArg, options);
|
||||
bufferization::getBufferType(bbArg, options, state);
|
||||
if (failed(memrefType))
|
||||
return failure();
|
||||
newTypes.push_back(*memrefType);
|
||||
@@ -463,7 +464,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
|
||||
continue;
|
||||
}
|
||||
FailureOr<BaseMemRefType> operandBufferType =
|
||||
bufferization::getBufferType(operand, options);
|
||||
bufferization::getBufferType(operand, options, state);
|
||||
if (failed(operandBufferType))
|
||||
return failure();
|
||||
rewriter.setInsertionPointAfterValue(operand);
|
||||
|
||||
@@ -213,6 +213,7 @@ struct CallOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto callOp = cast<func::CallOp>(op);
|
||||
|
||||
@@ -255,7 +256,7 @@ struct CallOpInterface
|
||||
|
||||
// Returning a memref.
|
||||
FailureOr<BaseMemRefType> resultType =
|
||||
bufferization::getBufferType(result, options);
|
||||
bufferization::getBufferType(result, options, state);
|
||||
if (failed(resultType))
|
||||
return failure();
|
||||
resultTypes.push_back(*resultType);
|
||||
@@ -278,7 +279,7 @@ struct CallOpInterface
|
||||
|
||||
// Retrieve buffers for tensor operands.
|
||||
FailureOr<Value> maybeBuffer =
|
||||
getBuffer(rewriter, opOperand.get(), options);
|
||||
getBuffer(rewriter, opOperand.get(), options, state);
|
||||
if (failed(maybeBuffer))
|
||||
return failure();
|
||||
Value buffer = *maybeBuffer;
|
||||
@@ -291,7 +292,8 @@ struct CallOpInterface
|
||||
// result type.
|
||||
FailureOr<BaseMemRefType> maybeMemRefType =
|
||||
bufferization::getBufferType(
|
||||
funcOp.getArgument(opOperand.getOperandNumber()), options);
|
||||
funcOp.getArgument(opOperand.getOperandNumber()), options,
|
||||
state);
|
||||
if (failed(maybeMemRefType))
|
||||
return failure();
|
||||
memRefType = *maybeMemRefType;
|
||||
@@ -396,6 +398,7 @@ struct FuncOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto funcOp = cast<FuncOp>(op);
|
||||
auto bbArg = cast<BlockArgument>(value);
|
||||
@@ -406,7 +409,7 @@ struct FuncOpInterface
|
||||
options);
|
||||
|
||||
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
|
||||
getBufferType(op, value, options, invocationStack);
|
||||
getBufferType(op, value, options, state, invocationStack);
|
||||
}
|
||||
|
||||
/// Rewrite function bbArgs and return values into buffer form. This function
|
||||
@@ -459,7 +462,7 @@ struct FuncOpInterface
|
||||
// 1. Bufferize every block.
|
||||
for (Block &block : funcOp.getBody())
|
||||
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
|
||||
options)))
|
||||
options, state)))
|
||||
return failure();
|
||||
|
||||
// 2. Bufferize the operands of all return ops.
|
||||
|
||||
@@ -1379,7 +1379,7 @@ LogicalResult bufferization::runOneShotBufferize(
|
||||
// Run One-Shot Analysis and insert buffer copies (on the tensor level)
|
||||
// only where needed. This is the default and much more efficient than
|
||||
// copy-before-write.
|
||||
if (failed(insertTensorCopies(op, options, statistics)))
|
||||
if (failed(insertTensorCopies(op, options, state, statistics)))
|
||||
return failure();
|
||||
|
||||
// If test-analysis-only is set, the IR was annotated with RaW conflict
|
||||
|
||||
@@ -584,7 +584,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
|
||||
"invalid combination of bufferization flags");
|
||||
if (!options.copyBeforeWrite) {
|
||||
if (options.noAnalysisFuncFilter.empty()) {
|
||||
if (failed(insertTensorCopies(moduleOp, options, statistics)))
|
||||
if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
|
||||
return failure();
|
||||
} else {
|
||||
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
|
||||
@@ -600,7 +600,8 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
|
||||
};
|
||||
OneShotBufferizationOptions updatedOptions(options);
|
||||
updatedOptions.opFilter.denyOperation(analysisFilterFn);
|
||||
if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
|
||||
if (failed(
|
||||
insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,28 +28,29 @@ using namespace mlir::bufferization;
|
||||
|
||||
LogicalResult mlir::bufferization::insertTensorCopies(
|
||||
Operation *op, const OneShotBufferizationOptions &options,
|
||||
const BufferizationState &bufferizationState,
|
||||
BufferizationStatistics *statistics) {
|
||||
OneShotAnalysisState state(op, options);
|
||||
OneShotAnalysisState analysisState(op, options);
|
||||
// Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
|
||||
// analysis depending on whether function boundary bufferization is enabled or
|
||||
// not.
|
||||
if (options.bufferizeFunctionBoundaries) {
|
||||
if (failed(analyzeModuleOp(cast<ModuleOp>(op), state, statistics)))
|
||||
if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
|
||||
return failure();
|
||||
} else {
|
||||
if (failed(analyzeOp(op, state, statistics)))
|
||||
if (failed(analyzeOp(op, analysisState, statistics)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (options.testAnalysisOnly)
|
||||
return success();
|
||||
|
||||
return insertTensorCopies(op, state);
|
||||
return insertTensorCopies(op, analysisState, bufferizationState);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
mlir::bufferization::insertTensorCopies(Operation *op,
|
||||
const AnalysisState &state) {
|
||||
LogicalResult mlir::bufferization::insertTensorCopies(
|
||||
Operation *op, const AnalysisState &analysisState,
|
||||
const BufferizationState &bufferizationState) {
|
||||
IRRewriter rewriter(op->getContext());
|
||||
|
||||
// It may be more efficient to walk in pre-order here, but the current
|
||||
@@ -62,14 +63,16 @@ mlir::bufferization::insertTensorCopies(Operation *op,
|
||||
nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
|
||||
return WalkResult::skip();
|
||||
|
||||
auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
|
||||
auto bufferizableOp =
|
||||
analysisState.getOptions().dynCastBufferizableOp(nestedOp);
|
||||
if (!bufferizableOp)
|
||||
return WalkResult::skip();
|
||||
|
||||
// Find inplacability conflicts and resolve them. (Typically with explicit
|
||||
// tensor copies in the form of AllocTensorOps.)
|
||||
rewriter.setInsertionPoint(nestedOp);
|
||||
if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
|
||||
if (failed(bufferizableOp.resolveConflicts(rewriter, analysisState,
|
||||
bufferizationState)))
|
||||
return WalkResult::interrupt();
|
||||
|
||||
return WalkResult::advance();
|
||||
|
||||
@@ -24,10 +24,9 @@ using namespace mlir::bufferization;
|
||||
namespace {
|
||||
|
||||
/// Generic conversion for any DestinationStyleOpInterface on tensors.
|
||||
static LogicalResult
|
||||
bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
|
||||
DestinationStyleOpInterface op,
|
||||
const BufferizationOptions &options) {
|
||||
static LogicalResult bufferizeDestinationStyleOpInterface(
|
||||
RewriterBase &rewriter, DestinationStyleOpInterface op,
|
||||
const BufferizationOptions &options, const BufferizationState &state) {
|
||||
// Take a guard before anything else.
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(op);
|
||||
@@ -49,7 +48,8 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
|
||||
newInputBuffers.push_back(opOperand->get());
|
||||
continue;
|
||||
}
|
||||
FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
|
||||
FailureOr<Value> buffer =
|
||||
getBuffer(rewriter, opOperand->get(), options, state);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
newInputBuffers.push_back(*buffer);
|
||||
@@ -60,7 +60,7 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
|
||||
for (OpResult opResult : op->getOpResults()) {
|
||||
OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
|
||||
FailureOr<Value> resultBuffer =
|
||||
getBuffer(rewriter, opOperand->get(), options);
|
||||
getBuffer(rewriter, opOperand->get(), options, state);
|
||||
if (failed(resultBuffer))
|
||||
return failure();
|
||||
newOutputBuffers.push_back(*resultBuffer);
|
||||
@@ -76,10 +76,10 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
|
||||
// new op. Since the new op does not have any tensor results, it does not
|
||||
// return anything.
|
||||
assert(op->getNumRegions() == 1 && "expected that op has 1 region");
|
||||
OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{},
|
||||
op->getAttrs());
|
||||
state.addRegion();
|
||||
Operation *newOp = Operation::create(state);
|
||||
OperationState opState(op->getLoc(), op->getName(), newOperands, TypeRange{},
|
||||
op->getAttrs());
|
||||
opState.addRegion();
|
||||
Operation *newOp = Operation::create(opState);
|
||||
newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
|
||||
op->getRegion(0).getBlocks());
|
||||
|
||||
@@ -151,7 +151,7 @@ struct LinalgOpInterface
|
||||
const BufferizationOptions &options,
|
||||
BufferizationState &state) const {
|
||||
return bufferizeDestinationStyleOpInterface(
|
||||
rewriter, cast<DestinationStyleOpInterface>(op), options);
|
||||
rewriter, cast<DestinationStyleOpInterface>(op), options, state);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -179,11 +179,11 @@ struct SoftmaxOpInterface
|
||||
BufferizationState &state) const {
|
||||
auto softmaxOp = cast<linalg::SoftmaxOp>(op);
|
||||
FailureOr<Value> inputBuffer =
|
||||
getBuffer(rewriter, softmaxOp.getInput(), options);
|
||||
getBuffer(rewriter, softmaxOp.getInput(), options, state);
|
||||
if (failed(inputBuffer))
|
||||
return failure();
|
||||
FailureOr<Value> outputBuffer =
|
||||
getBuffer(rewriter, softmaxOp.getOutput(), options);
|
||||
getBuffer(rewriter, softmaxOp.getOutput(), options, state);
|
||||
if (failed(outputBuffer))
|
||||
return failure();
|
||||
rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
|
||||
|
||||
@@ -138,7 +138,8 @@ struct GlobalStoreOpInterface
|
||||
auto targetMemref = rewriter.create<memref::GetGlobalOp>(
|
||||
loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
|
||||
|
||||
auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
|
||||
auto sourceMemref =
|
||||
getBuffer(rewriter, globalStoreOp.getValue(), options, state);
|
||||
if (failed(sourceMemref)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
@@ -104,11 +104,12 @@ struct ConditionOpInterface
|
||||
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
|
||||
Value value = it.value();
|
||||
if (isa<TensorType>(value.getType())) {
|
||||
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
|
||||
FailureOr<Value> maybeBuffer =
|
||||
getBuffer(rewriter, value, options, state);
|
||||
if (failed(maybeBuffer))
|
||||
return failure();
|
||||
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
|
||||
whileOp.getAfterArguments()[it.index()], options);
|
||||
whileOp.getAfterArguments()[it.index()], options, state);
|
||||
if (failed(resultType))
|
||||
return failure();
|
||||
Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
|
||||
@@ -196,7 +197,7 @@ struct ExecuteRegionOpInterface
|
||||
// Bufferize every block.
|
||||
for (Block &block : newOp.getRegion())
|
||||
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
|
||||
options)))
|
||||
options, state)))
|
||||
return failure();
|
||||
|
||||
// Update all uses of the old op.
|
||||
@@ -251,7 +252,7 @@ struct IfOpInterface
|
||||
newTypes.push_back(result.getType());
|
||||
continue;
|
||||
}
|
||||
auto bufferType = bufferization::getBufferType(result, options);
|
||||
auto bufferType = bufferization::getBufferType(result, options, state);
|
||||
if (failed(bufferType))
|
||||
return failure();
|
||||
newTypes.push_back(*bufferType);
|
||||
@@ -275,6 +276,7 @@ struct IfOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto ifOp = cast<scf::IfOp>(op);
|
||||
auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
|
||||
@@ -290,8 +292,8 @@ struct IfOpInterface
|
||||
// True branch was already bufferized.
|
||||
thenBufferType = cast<BaseMemRefType>(thenValue.getType());
|
||||
} else {
|
||||
auto maybeBufferType =
|
||||
bufferization::getBufferType(thenValue, options, invocationStack);
|
||||
auto maybeBufferType = bufferization::getBufferType(
|
||||
thenValue, options, state, invocationStack);
|
||||
if (failed(maybeBufferType))
|
||||
return failure();
|
||||
thenBufferType = *maybeBufferType;
|
||||
@@ -300,8 +302,8 @@ struct IfOpInterface
|
||||
// False branch was already bufferized.
|
||||
elseBufferType = cast<BaseMemRefType>(elseValue.getType());
|
||||
} else {
|
||||
auto maybeBufferType =
|
||||
bufferization::getBufferType(elseValue, options, invocationStack);
|
||||
auto maybeBufferType = bufferization::getBufferType(
|
||||
elseValue, options, state, invocationStack);
|
||||
if (failed(maybeBufferType))
|
||||
return failure();
|
||||
elseBufferType = *maybeBufferType;
|
||||
@@ -362,7 +364,7 @@ struct IndexSwitchOpInterface
|
||||
newTypes.push_back(result.getType());
|
||||
continue;
|
||||
}
|
||||
auto bufferType = bufferization::getBufferType(result, options);
|
||||
auto bufferType = bufferization::getBufferType(result, options, state);
|
||||
if (failed(bufferType))
|
||||
return failure();
|
||||
newTypes.push_back(*bufferType);
|
||||
@@ -390,6 +392,7 @@ struct IndexSwitchOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto switchOp = cast<scf::IndexSwitchOp>(op);
|
||||
assert(value.getDefiningOp() == op && "invalid value");
|
||||
@@ -401,8 +404,8 @@ struct IndexSwitchOpInterface
|
||||
Value yieldedValue = yieldOp->getOperand(resultNum);
|
||||
if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
|
||||
return bufferType;
|
||||
auto maybeBufferType =
|
||||
bufferization::getBufferType(yieldedValue, options, invocationStack);
|
||||
auto maybeBufferType = bufferization::getBufferType(
|
||||
yieldedValue, options, state, invocationStack);
|
||||
if (failed(maybeBufferType))
|
||||
return failure();
|
||||
return maybeBufferType;
|
||||
@@ -468,12 +471,12 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
|
||||
/// given OpOperands. If an operand is not a tensor, return the original value.
|
||||
static FailureOr<SmallVector<Value>>
|
||||
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
|
||||
const BufferizationOptions &options) {
|
||||
const BufferizationOptions &options, BufferizationState &state) {
|
||||
SmallVector<Value> result;
|
||||
for (OpOperand &opOperand : operands) {
|
||||
if (isa<TensorType>(opOperand.get().getType())) {
|
||||
FailureOr<Value> resultBuffer =
|
||||
getBuffer(rewriter, opOperand.get(), options);
|
||||
getBuffer(rewriter, opOperand.get(), options, state);
|
||||
if (failed(resultBuffer))
|
||||
return failure();
|
||||
result.push_back(*resultBuffer);
|
||||
@@ -521,10 +524,11 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
|
||||
/// layout map and a cast must be inserted.
|
||||
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
|
||||
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
|
||||
const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
|
||||
const BufferizationOptions &options, const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) {
|
||||
// Determine the buffer type of the init_arg.
|
||||
auto initArgBufferType =
|
||||
bufferization::getBufferType(initArg, options, invocationStack);
|
||||
bufferization::getBufferType(initArg, options, state, invocationStack);
|
||||
if (failed(initArgBufferType))
|
||||
return failure();
|
||||
|
||||
@@ -550,8 +554,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
|
||||
} else {
|
||||
// Note: This typically triggers a recursive call for the buffer type of
|
||||
// the iter_arg.
|
||||
auto maybeBufferType =
|
||||
bufferization::getBufferType(yieldedValue, options, invocationStack);
|
||||
auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
|
||||
state, invocationStack);
|
||||
if (failed(maybeBufferType))
|
||||
return failure();
|
||||
yieldedValueBufferType = *maybeBufferType;
|
||||
@@ -649,13 +653,16 @@ struct ForOpInterface
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
|
||||
const AnalysisState &state) const {
|
||||
LogicalResult
|
||||
resolveConflicts(Operation *op, RewriterBase &rewriter,
|
||||
const AnalysisState &analysisState,
|
||||
const BufferizationState &bufferizationState) const {
|
||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
|
||||
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
|
||||
rewriter, analysisState, bufferizationState)))
|
||||
return failure();
|
||||
|
||||
if (state.getOptions().copyBeforeWrite)
|
||||
if (analysisState.getOptions().copyBeforeWrite)
|
||||
return success();
|
||||
|
||||
// According to the `getAliasing...` implementations, a bufferized OpResult
|
||||
@@ -683,12 +690,13 @@ struct ForOpInterface
|
||||
doesNotAliasExternalValue(
|
||||
it.value(), &forOp.getRegion(),
|
||||
/*exceptions=*/forOp.getRegionIterArg(it.index()),
|
||||
static_cast<const OneShotAnalysisState &>(state))) {
|
||||
static_cast<const OneShotAnalysisState &>(analysisState))) {
|
||||
yieldValues.push_back(it.value());
|
||||
continue;
|
||||
}
|
||||
FailureOr<Value> alloc = allocateTensorForShapedValue(
|
||||
rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
|
||||
rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
|
||||
bufferizationState);
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
yieldValues.push_back(*alloc);
|
||||
@@ -701,6 +709,7 @@ struct ForOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto forOp = cast<scf::ForOp>(op);
|
||||
assert(getOwnerOfValue(value) == op && "invalid value");
|
||||
@@ -709,7 +718,8 @@ struct ForOpInterface
|
||||
if (auto opResult = dyn_cast<OpResult>(value)) {
|
||||
// The type of an OpResult must match the corresponding iter_arg type.
|
||||
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
|
||||
return bufferization::getBufferType(bbArg, options, invocationStack);
|
||||
return bufferization::getBufferType(bbArg, options, state,
|
||||
invocationStack);
|
||||
}
|
||||
|
||||
// Compute result/argument number.
|
||||
@@ -722,7 +732,7 @@ struct ForOpInterface
|
||||
BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
|
||||
Value initArg = forOp.getInitArgs()[resultNum];
|
||||
return computeLoopRegionIterArgBufferType(
|
||||
op, iterArg, initArg, yieldedValue, options, invocationStack);
|
||||
op, iterArg, initArg, yieldedValue, options, state, invocationStack);
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
@@ -737,7 +747,7 @@ struct ForOpInterface
|
||||
|
||||
// The new memref init_args of the loop.
|
||||
FailureOr<SmallVector<Value>> maybeInitArgs =
|
||||
getBuffers(rewriter, forOp.getInitArgsMutable(), options);
|
||||
getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
|
||||
if (failed(maybeInitArgs))
|
||||
return failure();
|
||||
SmallVector<Value> initArgs = *maybeInitArgs;
|
||||
@@ -752,7 +762,7 @@ struct ForOpInterface
|
||||
castedInitArgs.push_back(initArg);
|
||||
continue;
|
||||
}
|
||||
auto targetType = bufferization::getBufferType(result, options);
|
||||
auto targetType = bufferization::getBufferType(result, options, state);
|
||||
if (failed(targetType))
|
||||
return failure();
|
||||
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
|
||||
@@ -891,13 +901,16 @@ struct WhileOpInterface
|
||||
return true;
|
||||
}
|
||||
|
||||
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
|
||||
const AnalysisState &state) const {
|
||||
LogicalResult
|
||||
resolveConflicts(Operation *op, RewriterBase &rewriter,
|
||||
const AnalysisState &analysisState,
|
||||
const BufferizationState &bufferizationState) const {
|
||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
|
||||
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
|
||||
rewriter, analysisState, bufferizationState)))
|
||||
return failure();
|
||||
|
||||
if (state.getOptions().copyBeforeWrite)
|
||||
if (analysisState.getOptions().copyBeforeWrite)
|
||||
return success();
|
||||
|
||||
// According to the `getAliasing...` implementations, a bufferized OpResult
|
||||
@@ -914,9 +927,10 @@ struct WhileOpInterface
|
||||
// For every yielded value, is the value equivalent to its corresponding
|
||||
// bbArg?
|
||||
DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
|
||||
whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
|
||||
DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
|
||||
whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
|
||||
whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
|
||||
DenseSet<int64_t> equivalentYieldsAfter =
|
||||
getEquivalentBuffers(whileOp.getAfterArguments(),
|
||||
whileOp.getYieldOp().getResults(), analysisState);
|
||||
|
||||
// Update "before" region.
|
||||
rewriter.setInsertionPoint(conditionOp);
|
||||
@@ -931,7 +945,8 @@ struct WhileOpInterface
|
||||
continue;
|
||||
}
|
||||
FailureOr<Value> alloc = allocateTensorForShapedValue(
|
||||
rewriter, conditionOp.getLoc(), value, state.getOptions());
|
||||
rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
|
||||
bufferizationState);
|
||||
if (failed(alloc))
|
||||
return failure();
|
||||
beforeYieldValues.push_back(*alloc);
|
||||
@@ -956,7 +971,7 @@ struct WhileOpInterface
|
||||
|
||||
// The new memref init_args of the loop.
|
||||
FailureOr<SmallVector<Value>> maybeInitArgs =
|
||||
getBuffers(rewriter, whileOp.getInitsMutable(), options);
|
||||
getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
|
||||
if (failed(maybeInitArgs))
|
||||
return failure();
|
||||
SmallVector<Value> initArgs = *maybeInitArgs;
|
||||
@@ -971,7 +986,7 @@ struct WhileOpInterface
|
||||
castedInitArgs.push_back(initArg);
|
||||
continue;
|
||||
}
|
||||
auto targetType = bufferization::getBufferType(beforeArg, options);
|
||||
auto targetType = bufferization::getBufferType(beforeArg, options, state);
|
||||
if (failed(targetType))
|
||||
return failure();
|
||||
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
|
||||
@@ -984,7 +999,7 @@ struct WhileOpInterface
|
||||
return bbArg.getType();
|
||||
// TODO: error handling
|
||||
return llvm::cast<Type>(
|
||||
*bufferization::getBufferType(bbArg, options));
|
||||
*bufferization::getBufferType(bbArg, options, state));
|
||||
}));
|
||||
|
||||
// Construct a new scf.while op with memref instead of tensor values.
|
||||
@@ -1029,6 +1044,7 @@ struct WhileOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto whileOp = cast<scf::WhileOp>(op);
|
||||
assert(getOwnerOfValue(value) == op && "invalid value");
|
||||
@@ -1041,7 +1057,7 @@ struct WhileOpInterface
|
||||
auto yieldOp = whileOp.getYieldOp();
|
||||
Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
|
||||
return computeLoopRegionIterArgBufferType(
|
||||
op, bbArg, initArg, yieldedValue, options, invocationStack);
|
||||
op, bbArg, initArg, yieldedValue, options, state, invocationStack);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1062,7 +1078,7 @@ struct WhileOpInterface
|
||||
// scf.condition was already bufferized.
|
||||
return cast<BaseMemRefType>(conditionYieldedVal.getType());
|
||||
}
|
||||
return bufferization::getBufferType(conditionYieldedVal, options,
|
||||
return bufferization::getBufferType(conditionYieldedVal, options, state,
|
||||
invocationStack);
|
||||
}
|
||||
|
||||
@@ -1161,7 +1177,8 @@ struct YieldOpInterface
|
||||
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
|
||||
Value value = it.value();
|
||||
if (isa<TensorType>(value.getType())) {
|
||||
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
|
||||
FailureOr<Value> maybeBuffer =
|
||||
getBuffer(rewriter, value, options, state);
|
||||
if (failed(maybeBuffer))
|
||||
return failure();
|
||||
Value buffer = *maybeBuffer;
|
||||
@@ -1169,14 +1186,14 @@ struct YieldOpInterface
|
||||
if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
|
||||
yieldOp->getParentOp())) {
|
||||
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
|
||||
yieldOp->getParentOp()->getResult(it.index()), options);
|
||||
yieldOp->getParentOp()->getResult(it.index()), options, state);
|
||||
if (failed(resultType))
|
||||
return failure();
|
||||
buffer = castBuffer(rewriter, buffer, *resultType);
|
||||
} else if (auto whileOp =
|
||||
dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
|
||||
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
|
||||
whileOp.getBeforeArguments()[it.index()], options);
|
||||
whileOp.getBeforeArguments()[it.index()], options, state);
|
||||
if (failed(resultType))
|
||||
return failure();
|
||||
buffer = castBuffer(rewriter, buffer, *resultType);
|
||||
@@ -1236,7 +1253,7 @@ struct ForallOpInterface
|
||||
// Get buffers for all output operands.
|
||||
SmallVector<Value> buffers;
|
||||
for (Value out : forallOp.getOutputs()) {
|
||||
FailureOr<Value> buffer = getBuffer(rewriter, out, options);
|
||||
FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
buffers.push_back(*buffer);
|
||||
@@ -1283,6 +1300,7 @@ struct ForallOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto forallOp = cast<ForallOp>(op);
|
||||
|
||||
@@ -1290,13 +1308,14 @@ struct ForallOpInterface
|
||||
// A tensor block argument has the same bufferized type as the
|
||||
// corresponding output operand.
|
||||
return bufferization::getBufferType(
|
||||
forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
|
||||
forallOp.getTiedOpOperand(bbArg)->get(), options, state,
|
||||
invocationStack);
|
||||
|
||||
// The bufferized result type is the same as the bufferized type of the
|
||||
// corresponding output operand.
|
||||
return bufferization::getBufferType(
|
||||
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
|
||||
invocationStack);
|
||||
state, invocationStack);
|
||||
}
|
||||
|
||||
bool isRepetitiveRegion(Operation *op, unsigned index) const {
|
||||
|
||||
@@ -119,7 +119,7 @@ struct AssumingYieldOpInterface
|
||||
SmallVector<Value> newResults;
|
||||
for (Value value : yieldOp.getOperands()) {
|
||||
if (isa<TensorType>(value.getType())) {
|
||||
FailureOr<Value> buffer = getBuffer(rewriter, value, options);
|
||||
FailureOr<Value> buffer = getBuffer(rewriter, value, options, state);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
newResults.push_back(*buffer);
|
||||
|
||||
@@ -152,8 +152,10 @@ public:
|
||||
// invalidate the results of the analysis. From now on, only small and
|
||||
// localized rewrites are allowed, such as replacing a tensor op with its
|
||||
// memref equivalent.
|
||||
if (failed(bufferization::insertTensorCopies(getOperation(),
|
||||
bufferizationOptions)))
|
||||
bufferization::BufferizationState bufferizationState;
|
||||
|
||||
if (failed(bufferization::insertTensorCopies(
|
||||
getOperation(), bufferizationOptions, bufferizationState)))
|
||||
return signalPassFailure();
|
||||
|
||||
// Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
|
||||
|
||||
@@ -51,10 +51,11 @@ struct CastOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto castOp = cast<tensor::CastOp>(op);
|
||||
auto maybeSrcBufferType = bufferization::getBufferType(
|
||||
castOp.getSource(), options, invocationStack);
|
||||
castOp.getSource(), options, state, invocationStack);
|
||||
if (failed(maybeSrcBufferType))
|
||||
return failure();
|
||||
Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
|
||||
@@ -89,13 +90,13 @@ struct CastOpInterface
|
||||
|
||||
// The result buffer still has the old (pre-cast) type.
|
||||
FailureOr<Value> resultBuffer =
|
||||
getBuffer(rewriter, castOp.getSource(), options);
|
||||
getBuffer(rewriter, castOp.getSource(), options, state);
|
||||
if (failed(resultBuffer))
|
||||
return failure();
|
||||
|
||||
// Compute the new type.
|
||||
auto resultMemRefType =
|
||||
bufferization::getBufferType(castOp.getResult(), options);
|
||||
bufferization::getBufferType(castOp.getResult(), options, state);
|
||||
if (failed(resultMemRefType))
|
||||
return failure();
|
||||
if (resultBuffer->getType() == *resultMemRefType) {
|
||||
@@ -141,10 +142,11 @@ struct CollapseShapeOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
|
||||
auto maybeSrcBufferType = bufferization::getBufferType(
|
||||
collapseShapeOp.getSrc(), options, invocationStack);
|
||||
collapseShapeOp.getSrc(), options, state, invocationStack);
|
||||
if (failed(maybeSrcBufferType))
|
||||
return failure();
|
||||
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
|
||||
@@ -168,7 +170,7 @@ struct CollapseShapeOpInterface
|
||||
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
|
||||
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
|
||||
FailureOr<Value> maybeBuffer =
|
||||
getBuffer(rewriter, collapseShapeOp.getSrc(), options);
|
||||
getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
|
||||
if (failed(maybeBuffer))
|
||||
return failure();
|
||||
Value buffer = *maybeBuffer;
|
||||
@@ -210,7 +212,7 @@ struct CollapseShapeOpInterface
|
||||
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
|
||||
AnalysisState analysisState(options);
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), collapseShapeOp.getSrc(), options);
|
||||
rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
auto memrefType =
|
||||
@@ -252,7 +254,7 @@ struct DimOpInterface
|
||||
const BufferizationOptions &options,
|
||||
BufferizationState &state) const {
|
||||
auto dimOp = cast<tensor::DimOp>(op);
|
||||
FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
|
||||
FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
|
||||
if (failed(v))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
|
||||
@@ -286,7 +288,8 @@ struct EmptyOpInterface
|
||||
|
||||
// Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
|
||||
FailureOr<Value> allocTensor = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false);
|
||||
rewriter, op->getLoc(), emptyOp.getResult(), options, state,
|
||||
/*copy=*/false);
|
||||
if (failed(allocTensor))
|
||||
return failure();
|
||||
rewriter.replaceOp(op, *allocTensor);
|
||||
@@ -317,10 +320,11 @@ struct ExpandShapeOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
|
||||
auto maybeSrcBufferType = bufferization::getBufferType(
|
||||
expandShapeOp.getSrc(), options, invocationStack);
|
||||
expandShapeOp.getSrc(), options, state, invocationStack);
|
||||
if (failed(maybeSrcBufferType))
|
||||
return failure();
|
||||
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
|
||||
@@ -338,7 +342,7 @@ struct ExpandShapeOpInterface
|
||||
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
|
||||
auto tensorResultType = expandShapeOp.getResultType();
|
||||
FailureOr<Value> buffer =
|
||||
getBuffer(rewriter, expandShapeOp.getSrc(), options);
|
||||
getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
|
||||
@@ -382,13 +386,13 @@ struct ExtractSliceOpInterface
|
||||
|
||||
// Get source buffer.
|
||||
FailureOr<Value> srcMemref =
|
||||
getBuffer(rewriter, extractSliceOp.getSource(), options);
|
||||
getBuffer(rewriter, extractSliceOp.getSource(), options, state);
|
||||
if (failed(srcMemref))
|
||||
return failure();
|
||||
|
||||
// Take a subview of the source buffer.
|
||||
auto resultMemrefType =
|
||||
bufferization::getBufferType(extractSliceOp.getResult(), options);
|
||||
auto resultMemrefType = bufferization::getBufferType(
|
||||
extractSliceOp.getResult(), options, state);
|
||||
if (failed(resultMemrefType))
|
||||
return failure();
|
||||
Value subView = rewriter.create<memref::SubViewOp>(
|
||||
@@ -401,11 +405,12 @@ struct ExtractSliceOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
|
||||
assert(value == extractSliceOp.getResult() && "invalid value");
|
||||
auto srcMemrefType = bufferization::getBufferType(
|
||||
extractSliceOp.getSource(), options, invocationStack);
|
||||
extractSliceOp.getSource(), options, state, invocationStack);
|
||||
if (failed(srcMemrefType))
|
||||
return failure();
|
||||
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
|
||||
@@ -442,7 +447,7 @@ struct ExtractOpInterface
|
||||
BufferizationState &state) const {
|
||||
auto extractOp = cast<tensor::ExtractOp>(op);
|
||||
FailureOr<Value> srcMemref =
|
||||
getBuffer(rewriter, extractOp.getTensor(), options);
|
||||
getBuffer(rewriter, extractOp.getTensor(), options, state);
|
||||
if (failed(srcMemref))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
|
||||
@@ -491,12 +496,12 @@ struct FromElementsOpInterface
|
||||
auto shape = tensorType.getShape();
|
||||
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, loc, fromElementsOp.getResult(), options,
|
||||
rewriter, loc, fromElementsOp.getResult(), options, state,
|
||||
/*copy=*/false);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
FailureOr<BaseMemRefType> memrefType =
|
||||
bufferization::getBufferType(*tensorAlloc, options);
|
||||
bufferization::getBufferType(*tensorAlloc, options, state);
|
||||
if (failed(memrefType))
|
||||
return failure();
|
||||
Value buffer = rewriter.create<bufferization::ToBufferOp>(
|
||||
@@ -607,7 +612,7 @@ struct GenerateOpInterface
|
||||
// Allocate memory.
|
||||
Location loc = op->getLoc();
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, loc, generateOp.getResult(), options,
|
||||
rewriter, loc, generateOp.getResult(), options, state,
|
||||
/*copy=*/false);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
@@ -633,7 +638,7 @@ struct InsertOpInterface
|
||||
BufferizationState &state) const {
|
||||
auto insertOp = cast<tensor::InsertOp>(op);
|
||||
FailureOr<Value> destMemref =
|
||||
getBuffer(rewriter, insertOp.getDest(), options);
|
||||
getBuffer(rewriter, insertOp.getDest(), options, state);
|
||||
if (failed(destMemref))
|
||||
return failure();
|
||||
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
|
||||
@@ -695,7 +700,7 @@ struct InsertSliceOpInterface
|
||||
|
||||
// Get destination buffer.
|
||||
FailureOr<Value> dstMemref =
|
||||
getBuffer(rewriter, insertSliceOp.getDest(), options);
|
||||
getBuffer(rewriter, insertSliceOp.getDest(), options, state);
|
||||
if (failed(dstMemref))
|
||||
return failure();
|
||||
|
||||
@@ -712,7 +717,7 @@ struct InsertSliceOpInterface
|
||||
// Copy tensor. If this tensor.insert_slice has a matching
|
||||
// tensor.extract_slice, the copy operation will eventually fold away.
|
||||
FailureOr<Value> srcMemref =
|
||||
getBuffer(rewriter, insertSliceOp.getSource(), options);
|
||||
getBuffer(rewriter, insertSliceOp.getSource(), options, state);
|
||||
if (failed(srcMemref))
|
||||
return failure();
|
||||
if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
|
||||
@@ -749,11 +754,12 @@ struct PadOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
// Infer memory space from the source tensor.
|
||||
auto padOp = cast<tensor::PadOp>(op);
|
||||
auto maybeSrcBufferType = bufferization::getBufferType(
|
||||
padOp.getSource(), options, invocationStack);
|
||||
padOp.getSource(), options, state, invocationStack);
|
||||
if (failed(maybeSrcBufferType))
|
||||
return failure();
|
||||
MemRefLayoutAttrInterface layout;
|
||||
@@ -797,9 +803,9 @@ struct PadOpInterface
|
||||
}
|
||||
|
||||
// Allocate a buffer for the padded result.
|
||||
FailureOr<Value> tensorAlloc =
|
||||
allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options,
|
||||
/*copy=*/false);
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, loc, padOp.getResult(), options, state,
|
||||
/*copy=*/false);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
|
||||
@@ -846,7 +852,8 @@ struct RankOpInterface
|
||||
const BufferizationOptions &options,
|
||||
BufferizationState &state) const {
|
||||
auto rankOp = cast<tensor::RankOp>(op);
|
||||
FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
|
||||
FailureOr<Value> v =
|
||||
getBuffer(rewriter, rankOp.getTensor(), options, state);
|
||||
if (failed(v))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
|
||||
@@ -885,13 +892,13 @@ struct ReshapeOpInterface
|
||||
BufferizationState &state) const {
|
||||
auto reshapeOp = cast<tensor::ReshapeOp>(op);
|
||||
FailureOr<Value> srcBuffer =
|
||||
getBuffer(rewriter, reshapeOp.getSource(), options);
|
||||
getBuffer(rewriter, reshapeOp.getSource(), options, state);
|
||||
FailureOr<Value> shapeBuffer =
|
||||
getBuffer(rewriter, reshapeOp.getShape(), options);
|
||||
getBuffer(rewriter, reshapeOp.getShape(), options, state);
|
||||
if (failed(srcBuffer) || failed(shapeBuffer))
|
||||
return failure();
|
||||
auto maybeResultMemRefType =
|
||||
bufferization::getBufferType(reshapeOp.getResult(), options);
|
||||
bufferization::getBufferType(reshapeOp.getResult(), options, state);
|
||||
if (failed(maybeResultMemRefType))
|
||||
return failure();
|
||||
|
||||
@@ -901,7 +908,7 @@ struct ReshapeOpInterface
|
||||
auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
|
||||
if (srcType && !srcType.getLayout().isIdentity()) {
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, op->getLoc(), reshapeOp.getSource(), options);
|
||||
rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
auto memrefType = MemRefType::get(
|
||||
@@ -920,11 +927,12 @@ struct ReshapeOpInterface
|
||||
|
||||
FailureOr<BaseMemRefType>
|
||||
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
|
||||
const BufferizationState &state,
|
||||
SmallVector<Value> &invocationStack) const {
|
||||
auto reshapeOp = cast<tensor::ReshapeOp>(op);
|
||||
assert(value == reshapeOp.getResult() && "unexpected value provided");
|
||||
auto maybeSourceBufferType = bufferization::getBufferType(
|
||||
reshapeOp.getSource(), options, invocationStack);
|
||||
reshapeOp.getSource(), options, state, invocationStack);
|
||||
if (failed(maybeSourceBufferType))
|
||||
return failure();
|
||||
return getMemRefTypeWithStaticIdentityLayout(
|
||||
@@ -966,11 +974,11 @@ struct ParallelInsertSliceOpInterface
|
||||
|
||||
// Get source and destination buffers.
|
||||
FailureOr<Value> destBuffer =
|
||||
getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
|
||||
getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
|
||||
if (failed(destBuffer))
|
||||
return failure();
|
||||
FailureOr<Value> srcBuffer =
|
||||
getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
|
||||
getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
|
||||
if (failed(srcBuffer))
|
||||
return failure();
|
||||
|
||||
@@ -1015,8 +1023,10 @@ struct ParallelInsertSliceOpInterface
|
||||
|
||||
/// tensor.parallel_insert_slice op has implicit inplace behavior. We
|
||||
/// shouldn't create copy to resolve conflict.
|
||||
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
|
||||
const AnalysisState &state) const {
|
||||
LogicalResult
|
||||
resolveConflicts(Operation *op, RewriterBase &rewriter,
|
||||
const AnalysisState &analysisState,
|
||||
const BufferizationState &bufferizationState) const {
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -1038,7 +1048,7 @@ struct SplatOpInterface
|
||||
// Allocate memory.
|
||||
Location loc = op->getLoc();
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, loc, splatOp.getResult(), options,
|
||||
rewriter, loc, splatOp.getResult(), options, state,
|
||||
/*copy=*/false);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
@@ -1097,7 +1107,7 @@ struct ConcatOpInterface
|
||||
// Allocate memory.
|
||||
Location loc = op->getLoc();
|
||||
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
|
||||
rewriter, loc, concatOp.getResult(), options,
|
||||
rewriter, loc, concatOp.getResult(), options, state,
|
||||
/*copy=*/false);
|
||||
if (failed(tensorAlloc))
|
||||
return failure();
|
||||
@@ -1147,7 +1157,7 @@ struct ConcatOpInterface
|
||||
|
||||
for (auto operand : concatOp.getInputs()) {
|
||||
// Get the buffer for the operand.
|
||||
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
|
||||
FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
|
||||
if (failed(srcBuffer))
|
||||
return failure();
|
||||
|
||||
|
||||
@@ -53,7 +53,8 @@ struct TransferReadOpInterface
|
||||
auto readOp = cast<vector::TransferReadOp>(op);
|
||||
assert(isa<TensorType>(readOp.getShapedType()) &&
|
||||
"only tensor types expected");
|
||||
FailureOr<Value> buffer = getBuffer(rewriter, readOp.getBase(), options);
|
||||
FailureOr<Value> buffer =
|
||||
getBuffer(rewriter, readOp.getBase(), options, state);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
|
||||
@@ -112,7 +113,7 @@ struct TransferWriteOpInterface
|
||||
|
||||
// Create a new transfer_write on buffer that doesn't have a return value.
|
||||
FailureOr<Value> resultBuffer =
|
||||
getBuffer(rewriter, writeOp.getBase(), options);
|
||||
getBuffer(rewriter, writeOp.getBase(), options, state);
|
||||
if (failed(resultBuffer))
|
||||
return failure();
|
||||
rewriter.create<vector::TransferWriteOp>(
|
||||
@@ -155,7 +156,8 @@ struct GatherOpInterface
|
||||
auto gatherOp = cast<vector::GatherOp>(op);
|
||||
assert(isa<TensorType>(gatherOp.getBaseType()) &&
|
||||
"only tensor types expected");
|
||||
FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
|
||||
FailureOr<Value> buffer =
|
||||
getBuffer(rewriter, gatherOp.getBase(), options, state);
|
||||
if (failed(buffer))
|
||||
return failure();
|
||||
replaceOpWithNewBufferizedOp<vector::GatherOp>(
|
||||
@@ -184,10 +186,13 @@ struct MaskOpInterface
|
||||
return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
|
||||
}
|
||||
|
||||
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
|
||||
const AnalysisState &state) const {
|
||||
LogicalResult
|
||||
resolveConflicts(Operation *op, RewriterBase &rewriter,
|
||||
const AnalysisState &analysisState,
|
||||
const BufferizationState &bufferizationState) const {
|
||||
auto bufferizableOp = cast<BufferizableOpInterface>(op);
|
||||
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
|
||||
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
|
||||
rewriter, analysisState, bufferizationState)))
|
||||
return failure();
|
||||
|
||||
// TODO: Remove this function when vector.mask bodies can bufferize
|
||||
@@ -302,7 +307,8 @@ struct YieldOpInterface
|
||||
SmallVector<Value> newResults;
|
||||
for (Value value : yieldOp.getOperands()) {
|
||||
if (isa<TensorType>(value.getType())) {
|
||||
FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
|
||||
FailureOr<Value> maybeBuffer =
|
||||
getBuffer(rewriter, value, options, state);
|
||||
if (failed(maybeBuffer))
|
||||
return failure();
|
||||
newResults.push_back(*maybeBuffer);
|
||||
|
||||
@@ -48,7 +48,11 @@ struct TestTensorCopyInsertionPass
|
||||
options.defaultMemorySpaceFn =
|
||||
[](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
|
||||
}
|
||||
if (failed(bufferization::insertTensorCopies(getOperation(), options)))
|
||||
|
||||
bufferization::BufferizationState bufferizationState;
|
||||
|
||||
if (failed(bufferization::insertTensorCopies(getOperation(), options,
|
||||
bufferizationState)))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user