conv backward stage

tanyo/conv_backward_stage
TanyoKwok 2023-02-15 11:26:38 +08:00
parent df0a9d91dd
commit 349c2d7b48
2 changed files with 102 additions and 33 deletions


@ -22,6 +22,8 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
@@ -783,40 +785,49 @@ public:
auto outTy = getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
if (!inputTy || !weightTy || !outTy) {
return op.emitError("input, weight and output must be ranked tensors");
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
if (inputTy.getRank() < 3)
return op.emitError("only input with at least 3 dims valid");
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
SmallVector<int64_t> stride;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(stride))) {
return rewriter.notifyMatchFailure(op,
"non-const stride list unsupported");
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
SmallVector<int64_t> padding;
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding))) {
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
SmallVector<int64_t> dilation;
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation))) {
return rewriter.notifyMatchFailure(op,
"non-const dilation list unsupported");
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
SmallVector<int64_t> outputPadding;
if (!matchPattern(op.getOutputPadding(),
m_TorchListOfConstantInts(outputPadding))) {
return rewriter.notifyMatchFailure(
op, "non-const output_padding list unsupported");
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
int64_t groups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups))) {
return rewriter.notifyMatchFailure(op, "non-int groups unsupported");
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
bool transposed;
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) {
return rewriter.notifyMatchFailure(op, "non-bool transposed unsupported");
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
// Whether output padding needs to be handled
bool needHandleOutputPadding = false;
for (int64_t i : outputPadding) {
@@ -825,11 +836,13 @@ public:
break;
}
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
// Op validation check
if (needHandleOutputPadding && !transposed) {
return op->emitError(
"output padding attr is valid only in transposed convolution");
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
assert(padding.size() == dilation.size() &&
padding.size() == stride.size() &&
padding.size() == static_cast<size_t>(inputTy.getRank()) - 2 &&
@@ -838,15 +851,18 @@ public:
auto nSpatialDims = padding.size();
auto nDims = inputTy.getRank();
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
// Kernel size must be constant.
auto weightShape = weightTy.getShape();
for (int i = 2; i < nDims; ++i) {
if (weightShape[i] == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op, "only constant kernel size is supported");
}
}
// auto weightShape = weightTy.getShape();
// for (int i = 2; i < nDims; ++i) {
// llvm::dbgs() << __FILE__ << __LINE__ << "\n";
// if (weightShape[i] == ShapedType::kDynamic) {
// return rewriter.notifyMatchFailure(
// op, "only constant kernel size is supported");
// }
// }
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
Value mhloConvResult;
if (transposed) {
mhloConvResult = convertTransposedConv(
@@ -857,6 +873,7 @@ public:
stride, padding, dilation, groups);
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
auto bias = adaptor.getBias();
// No bias provided
@@ -864,6 +881,7 @@ public:
rewriter.replaceOp(op, mhloConvResult);
return success();
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
// Handle bias
if (!bias.getType().cast<RankedTensorType>()) {
@@ -876,6 +894,7 @@ public:
"legalization for bias supported");
}
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
assert(biasTy.getRank() <= 1);
// Reshape and promote bias
@@ -886,6 +905,7 @@ public:
bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims,
options.dimSizeIndexBits);
bias = mhlo::promoteType(rewriter, bias, outTy);
llvm::dbgs() << __FILE__ << __LINE__ << "\n";
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(op, outTy, mhloConvResult,


@@ -22,6 +22,7 @@
#include "llvm/ADT/StringSet.h"
#include <cstdint>
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
@@ -331,6 +332,22 @@ public:
};
} // namespace
namespace {
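// Decompose `aten.fill.Scalar` by materializing the scalar value as a
// rank-0 tensor and broadcasting it to the shape of `self` with
// `aten.expand_as`.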
class DecomposeAtenFillScalarOp
: public OpRewritePattern<AtenFillScalarOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFillScalarOp op,
PatternRewriter &rewriter) const override {
auto resType = op.getType().cast<BaseTensorType>();
Value valTensor = createRank0Tensor(rewriter, op.getLoc(), resType, op.getValue());
rewriter.replaceOpWithNewOp<AtenExpandAsOp>(op, resType, valTensor, op.getSelf());
return success();
}
};
} // namespace
namespace {
class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
public:
@@ -1423,10 +1440,6 @@ public:
if (!matchPattern(op.getOutputMask(), m_TorchListOfConstantBools(outMask)))
return rewriter.notifyMatchFailure(
op, "only constant bool output_mask is supported.");
// Support for `False` values for output mask unimplemented.
if (!llvm::all_of(outMask, [](bool mask) { return mask; }))
return rewriter.notifyMatchFailure(
op, "unimplemented: only true values for output_mask supported.");
bool transposed;
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
@@ -1473,34 +1486,69 @@ public:
padVal = rewriter.create<Torch::AtenSubIntOp>(loc, padVal, gradOutDim);
padVal = rewriter.create<Torch::AtenFloordivIntOp>(loc, padVal, cstTwo);
gradInputPaddingValues.push_back(padVal);
// gradInputPaddingValues.push_back(padVal);
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
gradInputPaddingValues.push_back(constantZero);
}
Value gradInputPadding = rewriter.create<Torch::PrimListConstructOp>(
loc, ListType::get(IntType::get(context)), gradInputPaddingValues);
Value weightTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, weight.getType(), weight, cstZero, cstOne);
// Convolve grad_output with weight.
Value gradInput = rewriter.create<Torch::AtenConvolutionOp>(
loc, op.getResultTypes()[0], gradOutput, weightTransposed, cstNone,
op.getStride(), gradInputPadding, op.getDilation(), op.getTransposed(),
op.getOutputPadding(), op.getGroups());
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, gradOutput.getType(), gradOutput, cstZero, cstOne);
Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, input.getType(), input, cstZero, cstOne);
// Convolve input with grad_output.
Value gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
loc, op.getResultTypes()[1], inputTransposed, gradOutputTransposed,
cstNone, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
op.getOutputPadding(), op.getGroups());
gradWeight = rewriter.create<Torch::AtenTransposeIntOp>(
loc, gradWeight.getType(), gradWeight, cstZero, cstOne);
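// Helper computing the tensor type produced by swapping two dimensions of
// `type`, keeping the element dtype unchanged.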
auto transposeType = [&](Type type, int64_t dim0, int64_t dim1) {
auto baseType = type.dyn_cast<BaseTensorType>();
SmallVector<int64_t> transposeShape =
llvm::to_vector(baseType.getSizes());
std::swap(transposeShape[dim0], transposeShape[dim1]);
return baseType.getWithSizesAndDtype(
llvm::ArrayRef(transposeShape), baseType.getOptionalDtype());
};
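// Gradients default to zero: a rank-0 zero tensor expanded to the operand's
// shape. They are recomputed below only when the corresponding output_mask
// entry is set.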
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cstZeroTensor = createRank0Tensor(rewriter, loc, input.getType().dyn_cast<BaseTensorType>(), constantZero);
Value gradInput = rewriter.create<AtenExpandAsOp>(
loc, op.getResultTypes()[0], cstZeroTensor, input);
if (outMask[0]) {
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), true);
// Convolve grad_output with weight.
gradInput = rewriter.create<Torch::AtenConvolutionOp>(
loc, op.getResultTypes()[0], gradOutput, weight, cstNone,
op.getStride(), op.getPadding(), op.getDilation(), /* transposed */cstTrue,
op.getOutputPadding(), op.getGroups());
// Value gradInputPadding = rewriter.create<Torch::PrimListConstructOp>(
// loc, ListType::get(IntType::get(context)), gradInputPaddingValues);
// Value weightTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
// loc, transposeType(weight.getType(), 0, 1), weight, cstZero, cstOne);
// // Convolve grad_output with weight.
// gradInput = rewriter.create<Torch::AtenConvolutionOp>(
// loc, op.getResultTypes()[0], gradOutput, weightTransposed, cstNone,
// op.getStride(), gradInputPadding, op.getDilation(), op.getTransposed(),
// op.getOutputPadding(), op.getGroups());
}
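// grad_weight likewise starts out as an all-zero tensor of the weight's shape.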
Value gradWeight = rewriter.create<AtenExpandAsOp>(
loc, op.getResultTypes()[1], cstZeroTensor, weight);
if (outMask[1]) {
// Convolve input with grad_output.
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, transposeType(gradOutput.getType(), 0, 1), gradOutput, cstZero, cstOne);
Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
loc, transposeType(input.getType(), 0, 1), input, cstZero, cstOne);
gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
loc, transposeType(op.getResultTypes()[1], 0, 1), inputTransposed, gradOutputTransposed,
cstNone, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
op.getOutputPadding(), op.getGroups());
gradWeight = rewriter.create<Torch::AtenTransposeIntOp>(
loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne);
}
SmallVector<Value> dimIntList{cstZero};
for (unsigned i = 2; i < gradRank; i++)
for (unsigned i = 2; i < gradRank; i++) {
dimIntList.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i)));
}
Value gradIntList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
dimIntList);
@@ -4271,6 +4319,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeDropoutBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSliceScatterOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenGroupNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFillScalarOp>(patterns);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;