mirror of https://github.com/llvm/torch-mlir
conv backward stage
parent df0a9d91dd
commit 349c2d7b48
@@ -22,6 +22,8 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

#include "llvm/Support/Debug.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
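A note on the tracing added in the hunks below: they insert unconditional llvm::dbgs() << __FILE__ << __LINE__ prints throughout the lowering. A minimal sketch, assuming a hypothetical DEBUG_TYPE tag that this commit does not define, of the usual LLVM idiom that keeps such tracing but gates it behind -debug-only:

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "torch-mlir-conv" // hypothetical tag; enable with -debug-only=torch-mlir-conv

// Emitted only in builds with LLVM debug output enabled and only when the tag
// above is requested, so the trace points can stay in the source without
// printing on every pipeline run.
LLVM_DEBUG(llvm::dbgs() << "convolution lowering reached " << __FILE__ << ":"
                        << __LINE__ << "\n");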
@@ -783,40 +785,49 @@ public:
    auto outTy = getTypeConverter()
                     ->convertType(op.getType())
                     .template cast<RankedTensorType>();
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    if (!inputTy || !weightTy || !outTy) {
      return op.emitError("input, weight and output must be ranked tensors");
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    if (inputTy.getRank() < 3)
      return op.emitError("only input with at least 3 dims valid");
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    SmallVector<int64_t> stride;
    if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const stride list unsupported");
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    SmallVector<int64_t> padding;
    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const padding list unsupported");
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    SmallVector<int64_t> dilation;
    if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation))) {
      return rewriter.notifyMatchFailure(op,
                                         "non-const dilation list unsupported");
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    SmallVector<int64_t> outputPadding;
    if (!matchPattern(op.getOutputPadding(),
                      m_TorchListOfConstantInts(outputPadding))) {
      return rewriter.notifyMatchFailure(
          op, "non-const output_padding list unsupported");
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    int64_t groups;
    if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
      return rewriter.notifyMatchFailure(op, "non-int groups unsupported");
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    bool transposed;
    if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) {
      return rewriter.notifyMatchFailure(op, "non-bool transposed unsupported");
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    // Whether we need to handle output_padding.
    bool needHandleOutputPadding = false;
    for (int64_t i : outputPadding) {
@@ -825,11 +836,13 @@ public:
        break;
      }
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    // Op validation check
    if (needHandleOutputPadding && !transposed) {
      return op->emitError(
          "output padding attr is valid only in transposed convolution");
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    assert(padding.size() == dilation.size() &&
           padding.size() == stride.size() &&
           padding.size() == static_cast<size_t>(inputTy.getRank()) - 2 &&
@@ -838,15 +851,18 @@ public:
    auto nSpatialDims = padding.size();
    auto nDims = inputTy.getRank();

    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    // Kernel size must be constant.
    auto weightShape = weightTy.getShape();
    for (int i = 2; i < nDims; ++i) {
      if (weightShape[i] == ShapedType::kDynamic) {
        return rewriter.notifyMatchFailure(
            op, "only constant kernel size is supported");
      }
    }
    // auto weightShape = weightTy.getShape();
    // for (int i = 2; i < nDims; ++i) {
    //   llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    //   if (weightShape[i] == ShapedType::kDynamic) {
    //     return rewriter.notifyMatchFailure(
    //         op, "only constant kernel size is supported");
    //   }
    // }

    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    Value mhloConvResult;
    if (transposed) {
      mhloConvResult = convertTransposedConv(
@@ -857,6 +873,7 @@ public:
          stride, padding, dilation, groups);
    }

    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    auto bias = adaptor.getBias();

    // No bias provided
@@ -864,6 +881,7 @@ public:
      rewriter.replaceOp(op, mhloConvResult);
      return success();
    }
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";

    // Handle bias
    if (!bias.getType().cast<RankedTensorType>()) {
@@ -876,6 +894,7 @@ public:
                         "legalization for bias supported");
    }

    llvm::dbgs() << __FILE__ << __LINE__ << "\n";
    assert(biasTy.getRank() <= 1);

    // Reshape and promote bias
@@ -886,6 +905,7 @@ public:
    bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
                                  options.dimSizeIndexBits);
    bias = mhlo::promoteType(rewriter, bias, outTy);
    llvm::dbgs() << __FILE__ << __LINE__ << "\n";

    DenseIntElementsAttr bcastDimensions;
    rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult,

@@ -22,6 +22,7 @@
#include "llvm/ADT/StringSet.h"
#include <cstdint>

#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
@@ -331,6 +332,22 @@ public:
};
} // namespace

namespace {
class DecomposeAtenFillScalarOp : public OpRewritePattern<AtenFillScalarOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenFillScalarOp op,
                                PatternRewriter &rewriter) const override {
    auto resType = op.getType().cast<BaseTensorType>();
    Value valTensor =
        createRank0Tensor(rewriter, op.getLoc(), resType, op.getValue());
    rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, resType, valTensor,
                                                op.getSelf());
    return success();
  }
};
} // namespace
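The DecomposeAtenFillScalarOp pattern above rewrites aten.fill.Scalar(self, value) by materializing value as a rank-0 tensor and expanding it to self's shape with aten.expand_as. A minimal sketch, not part of this commit, of a defensive guard one might place before createRank0Tensor, using the existing BaseTensorType helpers:

    // Sketch only: createRank0Tensor needs an element type, so bail out of
    // the pattern gracefully when the result type carries no dtype.
    auto resType = op.getType().cast<BaseTensorType>();
    if (!resType.hasDtype())
      return rewriter.notifyMatchFailure(op, "result type must have a dtype");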
namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public:
@@ -1423,10 +1440,6 @@ public:
    if (!matchPattern(op.getOutputMask(), m_TorchListOfConstantBools(outMask)))
      return rewriter.notifyMatchFailure(
          op, "only constant bool output_mask is supported.");
    // Support for `False` values for output mask unimplemented.
    if (!llvm::all_of(outMask, [](bool mask) { return mask; }))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only true values for output_mask supported.");

    bool transposed;
    if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
@@ -1473,34 +1486,69 @@ public:
      padVal = rewriter.create<Torch::AtenSubIntOp>(loc, padVal, gradOutDim);
      padVal = rewriter.create<Torch::AtenFloordivIntOp>(loc, padVal, cstTwo);

      gradInputPaddingValues.push_back(padVal);
      // gradInputPaddingValues.push_back(padVal);
      Value constantZero = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(0));
      gradInputPaddingValues.push_back(constantZero);
    }
    Value gradInputPadding = rewriter.create<Torch::PrimListConstructOp>(
        loc, ListType::get(IntType::get(context)), gradInputPaddingValues);
    Value weightTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
        loc, weight.getType(), weight, cstZero, cstOne);
    // Convolve grad_output with weight.
    Value gradInput = rewriter.create<Torch::AtenConvolutionOp>(
        loc, op.getResultTypes()[0], gradOutput, weightTransposed, cstNone,
        op.getStride(), gradInputPadding, op.getDilation(), op.getTransposed(),
        op.getOutputPadding(), op.getGroups());

    Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
        loc, gradOutput.getType(), gradOutput, cstZero, cstOne);
    Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
        loc, input.getType(), input, cstZero, cstOne);
    // Convolve input with grad_output.
    Value gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
        loc, op.getResultTypes()[1], inputTransposed, gradOutputTransposed,
        cstNone, op.getStride(), op.getPadding(), op.getDilation(),
        op.getTransposed(), op.getOutputPadding(), op.getGroups());
    gradWeight = rewriter.create<Torch::AtenTransposeIntOp>(
        loc, gradWeight.getType(), gradWeight, cstZero, cstOne);
    auto transposeType = [&](Type type, int64_t dim0, int64_t dim1) {
      auto baseType = type.dyn_cast<BaseTensorType>();
      SmallVector<int64_t> transposeShape =
          llvm::to_vector(baseType.getSizes());
      std::swap(transposeShape[dim0], transposeShape[dim1]);
      return baseType.getWithSizesAndDtype(
          llvm::ArrayRef(transposeShape), baseType.getOptionalDtype());
    };

    Value constantZero = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(0));
    Value cstZeroTensor = createRank0Tensor(
        rewriter, loc, input.getType().dyn_cast<BaseTensorType>(),
        constantZero);
    Value gradInput = rewriter.create<AtenExpandAsOp>(
        loc, op.getResultTypes()[0], cstZeroTensor, input);
    if (outMask[0]) {
      Value cstTrue =
          rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
      // Convolve grad_output with weight.
      gradInput = rewriter.create<Torch::AtenConvolutionOp>(
          loc, op.getResultTypes()[0], gradOutput, weight, cstNone,
          op.getStride(), op.getPadding(), op.getDilation(),
          /*transposed=*/cstTrue, op.getOutputPadding(), op.getGroups());

      // Value gradInputPadding = rewriter.create<Torch::PrimListConstructOp>(
      //     loc, ListType::get(IntType::get(context)), gradInputPaddingValues);
      // Value weightTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
      //     loc, transposeType(weight.getType(), 0, 1), weight, cstZero, cstOne);

      // // Convolve grad_output with weight.
      // gradInput = rewriter.create<Torch::AtenConvolutionOp>(
      //     loc, op.getResultTypes()[0], gradOutput, weightTransposed, cstNone,
      //     op.getStride(), gradInputPadding, op.getDilation(), op.getTransposed(),
      //     op.getOutputPadding(), op.getGroups());
    }
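For reference, the identity the outMask[0] branch relies on, written for groups = 1 and a single spatial dimension in the cross-correlation convention PyTorch uses (stride s, padding p, dilation d):

\[
y[n, c_o, t] = \sum_{c_i} \sum_{k} W[c_o, c_i, k] \, x[n, c_i, s\,t + d\,k - p],
\qquad
\frac{\partial L}{\partial x[n, c_i, u]} =
\sum_{c_o} \sum_{k} \sum_{t \,:\, s\,t + d\,k - p \,=\, u}
W[c_o, c_i, k] \, \frac{\partial L}{\partial y[n, c_o, t]},
\]

i.e. the input gradient is a transposed convolution of grad_output with the forward weight, which is what the AtenConvolutionOp call above expresses by passing transposed = true.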

    Value gradWeight = rewriter.create<AtenExpandAsOp>(
        loc, op.getResultTypes()[1], cstZeroTensor, weight);
    if (outMask[1]) {
      // Convolve input with grad_output.
      Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
          loc, transposeType(gradOutput.getType(), 0, 1), gradOutput, cstZero,
          cstOne);
      Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
          loc, transposeType(input.getType(), 0, 1), input, cstZero, cstOne);
      // Convolve input with grad_output.
      gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
          loc, transposeType(op.getResultTypes()[1], 0, 1), inputTransposed,
          gradOutputTransposed, cstNone, op.getStride(), op.getPadding(),
          op.getDilation(), op.getTransposed(), op.getOutputPadding(),
          op.getGroups());
      gradWeight = rewriter.create<Torch::AtenTransposeIntOp>(
          loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne);
    }
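The outMask[1] branch computes, in the same conventions,

\[
\frac{\partial L}{\partial W[c_o, c_i, k]} =
\sum_{n} \sum_{t} \frac{\partial L}{\partial y[n, c_o, t]} \, x[n, c_i, s\,t + d\,k - p].
\]

With unit stride and dilation this sum is itself a convolution in which the batch and channel axes trade places, which is why the code transposes dims 0 and 1 of input and grad_output, convolves, and transposes the result back. For general stride and dilation the forward stride plays the role of the dilation of this auxiliary convolution, a correspondence worth checking against the attributes forwarded above.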

    SmallVector<Value> dimIntList{cstZero};
    for (unsigned i = 2; i < gradRank; i++)
    for (unsigned i = 2; i < gradRank; i++) {
      dimIntList.push_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(i)));
    }
    Value gradIntList = rewriter.create<Torch::PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
        dimIntList);
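The dim list assembled just above presumably feeds the bias-gradient reduction, since for grad_output of rank r the bias gradient is

\[
\frac{\partial L}{\partial b[c_o]} =
\sum_{n} \sum_{t_1, \dots, t_{r-2}}
\frac{\partial L}{\partial y[n, c_o, t_1, \dots, t_{r-2}]},
\]

a sum over dim 0 (batch) and dims 2 through r-1 (spatial), exactly the indices collected into dimIntList.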
@@ -4271,6 +4319,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenNativeDropoutBackwardOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSliceScatterOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenFillScalarOp>(patterns);

    GreedyRewriteConfig config;
    config.useTopDownTraversal = true;