mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support aten.convolution_backward op
This commit adds the decomposition for the `aten.convolution_backward` and `aten.convolution_backward_overrideable` ops. (pull/1572/head, snapshot-20221115.658)
parent
f40cbd6a71
commit
92f385bd9f
|
@ -622,4 +622,8 @@ LTC_XFAIL_SET = {
|
|||
"Fill_TensorFloat32WithInt64_basic",
|
||||
"UpSampleNearest2dBackwardVec_basic",
|
||||
"UpSampleNearest2dBackwardOutputSizeNone_basic",
|
||||
"ConvolutionBackwardModule1D_basic",
|
||||
"ConvolutionBackwardModule2D_basic",
|
||||
"ConvolutionBackwardModule2DPadded_basic",
|
||||
"ConvolutionBackwardModule3D_basic"
|
||||
}
|
||||
|
|
|
@ -3977,6 +3977,75 @@ def Torch_AtenRollOp : Torch_Op<"aten.roll", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenConvolutionBackwardOp : Torch_Op<"aten.convolution_backward", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad_output,
|
||||
AnyTorchTensorType:$input,
|
||||
AnyTorchTensorType:$weight,
|
||||
AnyTorchOptionalListOfTorchIntType:$bias_sizes,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
AnyTorchListOfTorchIntType:$dilation,
|
||||
Torch_BoolType:$transposed,
|
||||
AnyTorchListOfTorchIntType:$output_padding,
|
||||
Torch_IntType:$groups,
|
||||
AnyTorchListOfTorchBoolType:$output_mask
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result0,
|
||||
AnyTorchTensorType:$result1,
|
||||
AnyTorchTensorType:$result2
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenConvolutionBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 11, 3);
|
||||
}
|
||||
void AtenConvolutionBackwardOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 11, 3);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenConvolutionBackwardOverrideableOp : Torch_Op<"aten.convolution_backward_overrideable", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad_output,
|
||||
AnyTorchTensorType:$input,
|
||||
AnyTorchTensorType:$weight,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
AnyTorchListOfTorchIntType:$dilation,
|
||||
Torch_BoolType:$transposed,
|
||||
AnyTorchListOfTorchIntType:$output_padding,
|
||||
Torch_IntType:$groups,
|
||||
AnyTorchListOfTorchBoolType:$output_mask
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$grad_input,
|
||||
AnyTorchTensorType:$grad_weight,
|
||||
AnyTorchTensorType:$grad_bias
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenConvolutionBackwardOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 10, 3);
|
||||
}
|
||||
void AtenConvolutionBackwardOverrideableOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 10, 3);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -150,6 +150,37 @@ m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
|
|||
return detail::torch_list_construct_op_binder(bind_values);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
/// Matches the constant bools stored in a `torch.ListConstruct`.
|
||||
struct torch_bool_list_construct_op_binder {
|
||||
SmallVectorImpl<bool> &bind_values;
|
||||
|
||||
/// Creates a matcher instance that binds the value to bvs if match succeeds.
|
||||
torch_bool_list_construct_op_binder(SmallVectorImpl<bool> &bvs)
|
||||
: bind_values(bvs) {}
|
||||
|
||||
bool match(Operation *op) {
|
||||
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
|
||||
if (!listConstruct)
|
||||
return false;
|
||||
for (Value value : listConstruct.elements()) {
|
||||
bool num;
|
||||
if (matchPattern(value, m_TorchConstantBool(&num)))
|
||||
bind_values.push_back(num);
|
||||
else
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// Matches the constant bools stored in a `torch.prim.ListConstruct`.
|
||||
inline detail::torch_bool_list_construct_op_binder
|
||||
m_TorchConstantBoolList(SmallVectorImpl<bool> &bind_values) {
|
||||
return detail::torch_bool_list_construct_op_binder(bind_values);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
/// Matches the expected tensor and dim from `torch.aten.size.int`.
|
||||
struct torch_tensor_size_int_op_binder {
|
||||
|
|
|
@ -398,6 +398,9 @@ public:
|
|||
target.addIllegalOp<AtenDivFloatOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenFloordivIntOp>();
|
||||
patterns.add<ConvertAtenBinaryOp<AtenFloordivIntOp, arith::FloorDivSIOp>>(
|
||||
typeConverter, context);
|
||||
target.addIllegalOp<AtenCeilFloatOp>();
|
||||
patterns
|
||||
.add<ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(
|
||||
|
|
|
@ -94,7 +94,6 @@ public:
|
|||
|
||||
namespace {
|
||||
class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
|
||||
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
|
@ -501,11 +500,12 @@ public:
|
|||
return rewriter.create<arith::IndexCastOp>(loc, intType, v);
|
||||
};
|
||||
|
||||
SmallVector<int64_t> paddingInts;
|
||||
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) {
|
||||
SmallVector<Value> paddingIntValues;
|
||||
if (!getListConstructElements(op.padding(), paddingIntValues))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support constant padding values");
|
||||
}
|
||||
op, "only support padding from a list construct");
|
||||
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
|
||||
paddingIntValues);
|
||||
SmallVector<int64_t> strideInts;
|
||||
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
|
@ -548,8 +548,6 @@ public:
|
|||
"invalid: groups must divide weight batch size evenly.");
|
||||
SmallVector<Value> dilationIntValues =
|
||||
getAsConstantIntValues(rewriter, loc, dilationInts);
|
||||
SmallVector<Value> paddingIntValues =
|
||||
getAsConstantIntValues(rewriter, loc, paddingInts);
|
||||
SmallVector<Value> strideIntValues =
|
||||
getAsConstantIntValues(rewriter, loc, strideInts);
|
||||
|
||||
|
@ -647,11 +645,8 @@ public:
|
|||
|
||||
} else {
|
||||
// Pad input
|
||||
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
|
||||
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
|
||||
paddingInts.end());
|
||||
paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input,
|
||||
paddingIncludingNC);
|
||||
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
|
||||
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2);
|
||||
|
||||
// Calculate output dims
|
||||
for (size_t i = 0; i < numSpacialDims; i++)
|
||||
|
@ -752,7 +747,6 @@ public:
|
|||
}
|
||||
|
||||
// Grouped case, use the grouped conv linalg op
|
||||
|
||||
auto expandGroups = [&](Value tensor, size_t dim) {
|
||||
auto inType = tensor.getType().cast<RankedTensorType>();
|
||||
auto inShape = inType.getShape();
|
||||
|
|
|
@ -67,6 +67,40 @@ Value torch_to_linalg::getZeroPaddedTensor(
|
|||
return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
|
||||
}
|
||||
|
||||
// Helper function that adds dynamic padding to a tensor, ignoring unpaddedDims
|
||||
// dimensions at the beginning. The high and low padding are the same, and the
|
||||
// padding value is zero.
|
||||
Value torch_to_linalg::getDynamicZeroPaddedTensor(
|
||||
Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding,
|
||||
int unpaddedDims) {
|
||||
assert(input.getType().isa<RankedTensorType>() &&
|
||||
"input must be RankedTensorType");
|
||||
unsigned int inRank = input.getType().cast<RankedTensorType>().getRank();
|
||||
Location loc = op->getLoc();
|
||||
|
||||
SmallVector<Value> inputDims = getTensorSizes(b, loc, input);
|
||||
Value c0 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
|
||||
SmallVector<Value> paddingIncludingUnchanged(unpaddedDims, c0);
|
||||
paddingIncludingUnchanged.append(padding);
|
||||
assert(unpaddedDims + padding.size() == inRank &&
|
||||
"sum of unpaddedDims and padding.size() must equal to inputRank");
|
||||
for (auto pad = paddingIncludingUnchanged.begin();
|
||||
pad < paddingIncludingUnchanged.end(); pad++)
|
||||
*pad = castIntToIndex(b, loc, *pad);
|
||||
|
||||
Type elementType = input.getType().cast<RankedTensorType>().getElementType();
|
||||
Type inputType = RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>(SmallVector<int64_t>(inRank, kUnknownSize)),
|
||||
elementType);
|
||||
|
||||
Value cf0 =
|
||||
b.create<arith::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
|
||||
SmallVector<OpFoldResult> paddingValues =
|
||||
getAsOpFoldResult(paddingIncludingUnchanged);
|
||||
return b.create<tensor::PadOp>(loc, inputType, input, /*low=*/paddingValues,
|
||||
/*high=*/paddingValues, cf0);
|
||||
}
|
||||
|
||||
Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
|
||||
Value in, Value paddingInt,
|
||||
Value dilationInt,
|
||||
|
|
|
@ -30,6 +30,13 @@ Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
|||
Value getZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
||||
SmallVectorImpl<int64_t> &paddingInts);
|
||||
|
||||
// Helper function that adds dynamic padding to a tensor, ignoring unpaddedDims
|
||||
// dimensions at the beginning. The high and low padding are the same, and the
|
||||
// padding value is zero.
|
||||
Value getDynamicZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
||||
SmallVectorImpl<Value> &padding,
|
||||
int unpaddedDims = 0);
|
||||
|
||||
// Helper function to calculate the output tensor dims for convolution-like ops.
|
||||
// Along each dim:
|
||||
// dim_out =
|
||||
|
|
|
@ -1067,7 +1067,7 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.convolution_overrideable to aten.convolution
|
||||
// Decompose aten.convolution_overrideable to aten.convolution op.
|
||||
namespace {
|
||||
class DecomposeAtenConvolutionOverrideableOp
|
||||
: public OpRewritePattern<AtenConvolutionOverrideableOp> {
|
||||
|
@ -1088,7 +1088,7 @@ public:
|
|||
|
||||
// Decompose aten._convolution-like to aten.convolution
|
||||
namespace {
|
||||
template<typename ConvolutionLikeOp>
|
||||
template <typename ConvolutionLikeOp>
|
||||
class DecomposeAten_ConvolutionLikeOp
|
||||
: public OpRewritePattern<ConvolutionLikeOp> {
|
||||
public:
|
||||
|
@ -1142,7 +1142,174 @@ public:
|
|||
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
|
||||
op.stride(), op.padding(), op.dilation(), /*transposed=*/cstTrue,
|
||||
op.output_padding(), op.groups());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.convolution_backward_overrideable to aten.convolution_backward
|
||||
// op.
|
||||
namespace {
|
||||
class DecomposeAtenConvolutionBackwardOverrideableOp
|
||||
: public OpRewritePattern<AtenConvolutionBackwardOverrideableOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenConvolutionBackwardOverrideableOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
|
||||
rewriter.replaceOpWithNewOp<AtenConvolutionBackwardOp>(
|
||||
op, op.getResultTypes(), op.grad_output(), op.input(), op.weight(),
|
||||
none, op.stride(), op.padding(), op.dilation(), op.transposed(),
|
||||
op.output_padding(), op.groups(), op.output_mask());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenConvolutionBackwardOp
|
||||
: public OpRewritePattern<AtenConvolutionBackwardOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenConvolutionBackwardOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op.getContext();
|
||||
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(1));
|
||||
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(2));
|
||||
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
|
||||
loc, rewriter.getBoolAttr(false));
|
||||
|
||||
Value gradOutput = op.grad_output();
|
||||
Value input = op.input();
|
||||
Value weight = op.weight();
|
||||
auto gradRank = getTensorRank(gradOutput);
|
||||
|
||||
if (gradRank != 4)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only 2D convolutions supported.");
|
||||
|
||||
SmallVector<Value> padding;
|
||||
if (!getListConstructElements(op.padding(), padding))
|
||||
return rewriter.notifyMatchFailure(op, "padding must be a list.");
|
||||
|
||||
SmallVector<Value> strides;
|
||||
if (!getListConstructElements(op.stride(), strides))
|
||||
return rewriter.notifyMatchFailure(op, "stride must be a list.");
|
||||
for (Value stride : strides) {
|
||||
Value cmp = rewriter.create<Torch::AtenEqIntOp>(loc, stride, cstOne);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
loc, cmp, "unimplemented: only strides of 1 supported.");
|
||||
}
|
||||
|
||||
SmallVector<Value> dilations;
|
||||
if (!getListConstructElements(op.dilation(), dilations))
|
||||
return rewriter.notifyMatchFailure(op, "dilation must be a list.");
|
||||
for (Value dilation : dilations) {
|
||||
Value cmp = rewriter.create<Torch::AtenEqIntOp>(loc, dilation, cstOne);
|
||||
rewriter.create<Torch::RuntimeAssertOp>(
|
||||
loc, cmp, "unimplemented: only dilations of 1 supported.");
|
||||
}
|
||||
|
||||
SmallVector<bool> outMask;
|
||||
if (!matchPattern(op.output_mask(), m_TorchConstantBoolList(outMask)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant bool output_mask is supported.");
|
||||
// Support for `False` values for output mask unimplemented.
|
||||
if (!llvm::all_of(outMask, [](bool mask) { return mask; }))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: only true values for output_mask supported.");
|
||||
|
||||
bool transposed;
|
||||
if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only constant transposed is supported.");
|
||||
if (transposed)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: transposed convolutions are not supported.");
|
||||
|
||||
// Rotate weight.
|
||||
SmallVector<Value> axes;
|
||||
for (auto i = 2; i < gradRank; i++) {
|
||||
axes.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i)));
|
||||
}
|
||||
Value axesList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), axes);
|
||||
weight = rewriter.create<Torch::AtenFlipOp>(loc, weight.getType(), weight,
|
||||
axesList);
|
||||
// Calculate padding for first convolution.
|
||||
SmallVector<Value> gradInputPaddingValues;
|
||||
for (auto i = 2; i < gradRank; i++) {
|
||||
Value dim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i));
|
||||
Value outDim = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);
|
||||
|
||||
// Calculate 1 + (weightDim // 2) * 2, which fixes issues with
|
||||
// even-sized weight.
|
||||
Value weightDim = rewriter.create<Torch::AtenSizeIntOp>(loc, weight, dim);
|
||||
weightDim =
|
||||
rewriter.create<Torch::AtenFloordivIntOp>(loc, weightDim, cstTwo);
|
||||
weightDim = rewriter.create<Torch::AtenMulIntOp>(loc, weightDim, cstTwo);
|
||||
weightDim = rewriter.create<Torch::AtenAddIntOp>(loc, weightDim, cstOne);
|
||||
Value gradOutDim =
|
||||
rewriter.create<Torch::AtenSizeIntOp>(loc, gradOutput, dim);
|
||||
|
||||
// Calculate (((outDim - 1) * stride) + weightDim - gradOutDim) // 2,
|
||||
// the padding value for this dimension. Derived from the formula at
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
|
||||
Value padVal = rewriter.create<Torch::AtenSubIntOp>(loc, outDim, cstOne);
|
||||
padVal =
|
||||
rewriter.create<Torch::AtenMulIntOp>(loc, padVal, strides[i - 2]);
|
||||
padVal = rewriter.create<Torch::AtenAddIntOp>(loc, padVal, weightDim);
|
||||
padVal = rewriter.create<Torch::AtenSubIntOp>(loc, padVal, gradOutDim);
|
||||
padVal = rewriter.create<Torch::AtenFloordivIntOp>(loc, padVal, cstTwo);
|
||||
|
||||
gradInputPaddingValues.push_back(padVal);
|
||||
}
|
||||
Value gradInputPadding = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, ListType::get(IntType::get(context)), gradInputPaddingValues);
|
||||
Value weightTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||
loc, weight.getType(), weight, cstZero, cstOne);
|
||||
// Convolve grad_output with weight.
|
||||
Value gradInput = rewriter.create<Torch::AtenConvolutionOp>(
|
||||
loc, op.getResultTypes()[0], gradOutput, weightTransposed, cstNone,
|
||||
op.stride(), gradInputPadding, op.dilation(), op.transposed(),
|
||||
op.output_padding(), op.groups());
|
||||
|
||||
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||
loc, gradOutput.getType(), gradOutput, cstZero, cstOne);
|
||||
Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||
loc, input.getType(), input, cstZero, cstOne);
|
||||
// Convolve input with grad_output.
|
||||
Value gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
|
||||
loc, op.getResultTypes()[1], inputTransposed, gradOutputTransposed,
|
||||
cstNone, op.stride(), op.padding(), op.dilation(), op.transposed(),
|
||||
op.output_padding(), op.groups());
|
||||
gradWeight = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||
loc, gradWeight.getType(), gradWeight, cstZero, cstOne);
|
||||
|
||||
SmallVector<Value> dimIntList{cstZero};
|
||||
for (auto i = 2; i < gradRank; i++)
|
||||
dimIntList.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i)));
|
||||
Value gradIntList = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
||||
dimIntList);
|
||||
// Sum grad_output along dim 1.
|
||||
Value gradBias = rewriter.create<Torch::AtenSumDimIntListOp>(
|
||||
loc, op.getResultTypes()[2], gradOutput, gradIntList, cstFalse,
|
||||
cstNone);
|
||||
|
||||
rewriter.replaceOp(op, {gradInput, gradWeight, gradBias});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -2940,6 +3107,8 @@ public:
|
|||
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
||||
patterns.add<DecomposeAtenWhereScalarSelfOp>(context);
|
||||
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
||||
patterns.add<DecomposeAtenConvolutionBackwardOverrideableOp>(context);
|
||||
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
|
||||
patterns.add<DecomposeAtenSizeOp>(context);
|
||||
target.addIllegalOp<AtenSizeOp>();
|
||||
patterns.add<DecomposeAtenReshapeOp>(context);
|
||||
|
@ -2989,6 +3158,8 @@ public:
|
|||
patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
|
||||
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenConvolutionBackwardOp>();
|
||||
patterns.add<DecomposeAtenConvolutionBackwardOp>(context);
|
||||
target.addIllegalOp<AtenConv2dOp>();
|
||||
patterns.add<DecomposeAtenConv2dOp>(context);
|
||||
target.addIllegalOp<AtenConvTranspose2dInputOp>();
|
||||
|
|
|
@ -885,7 +885,9 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
}
|
||||
|
||||
// 3 results take dtype from first operand.
|
||||
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp>(op)) {
|
||||
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp,
|
||||
AtenConvolutionBackwardOp, AtenConvolutionBackwardOverrideableOp>(
|
||||
op)) {
|
||||
auto self = operands[0]->getValue();
|
||||
auto result0Knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
|
|
|
@ -6544,6 +6544,16 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
|
|||
" func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.convolution_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.bool, %arg8: !torch.list<int>, %arg9: !torch.int, %arg10: !torch.list<bool>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.convolution_backward_overrideable\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.list<bool>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" return %1 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
|
||||
" return %arg0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
|
|
|
@ -981,6 +981,12 @@ def aten〇_convolution〇deprecated(input: List[int], weight: List[int], bias:
|
|||
def aten〇flip(self: List[int], dims: List[int]) -> List[int]:
|
||||
return self
|
||||
|
||||
def aten〇convolution_backward(grad_output: List[int], input: List[int], weight: List[int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
|
||||
return upstream_shape_functions.conv_backwards(grad_output, input, weight, bias_sizes)
|
||||
|
||||
def aten〇convolution_backward_overrideable(grad_output: List[int], input: List[int], weight: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
|
||||
return upstream_shape_functions.conv_backwards(grad_output, input, weight, None)
|
||||
|
||||
def aten〇batch_norm(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
|
||||
# Torch's symbolic shape analysis is a bit looser about optional
|
||||
# arguments than we are, so their batch_norm helper function works
|
||||
|
|
|
@ -350,6 +350,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
|
||||
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
|
||||
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
|
||||
emit("aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
||||
emit(
|
||||
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
||||
|
|
|
@ -15,6 +15,8 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
|
|||
"Conv_Transpose1dModule_basic",
|
||||
"MaxPool2dWith3dInputModule_basic",
|
||||
"MaxPool2dWithIndicesWith3dInputModule_basic",
|
||||
"ConvolutionBackwardModule1D_basic",
|
||||
"ConvolutionBackwardModule3D_basic",
|
||||
}
|
||||
|
||||
def register_all_tests():
|
||||
|
|
|
@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class SoftmaxBackwardModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -31,11 +32,12 @@ class SoftmaxBackwardModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: SoftmaxBackwardModule())
|
||||
def SoftmaxBackwardModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
|
||||
module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
class TanhBackwardModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -51,12 +53,153 @@ class TanhBackwardModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: TanhBackwardModule())
|
||||
def TanhBackward_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 3), torch.randn(3, 3))
|
||||
module.forward(tu.rand(3, 3), tu.rand(3, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ConvolutionBackwardModule1D(torch.nn.Module):
    # Calls aten.convolution_backward on three dynamically shaped rank-3
    # float32 tensors with all three output_mask entries enabled.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad_out, input_vec, weight):
        return torch.ops.aten.convolution_backward(grad_out,
                                                   input_vec,
                                                   weight,
                                                   bias_sizes=None,
                                                   stride=[1],
                                                   padding=[0],
                                                   dilation=[1],
                                                   transposed=False,
                                                   output_padding=[0],
                                                   groups=1,
                                                   output_mask=[True, True, True])


@register_test_case(module_factory=lambda: ConvolutionBackwardModule1D())
def ConvolutionBackwardModule1D_basic(module, tu: TestUtils):
    # The inputs are drawn in argument order inside the mkldnn-disabled scope.
    with torch.backends.mkldnn.flags(enabled=False):
        grad_out = tu.rand(3, 3, 3)
        input_vec = tu.rand(3, 3, 3)
        weight = tu.rand(3, 3, 1)
        module.forward(grad_out, input_vec, weight)
|
||||
|
||||
|
||||
class ConvolutionBackwardModule2D(torch.nn.Module):
    # Calls aten.convolution_backward on three dynamically shaped rank-4
    # float32 tensors (unpadded 2-D convolution, unit stride and dilation).

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad_out, input_vec, weight):
        return torch.ops.aten.convolution_backward(
            grad_out,
            input_vec,
            weight,
            bias_sizes=None,
            stride=[1, 1],
            padding=[0, 0],
            dilation=[1, 1],
            transposed=False,
            # One entry per spatial dimension, matching stride/padding/dilation.
            output_padding=[0, 0],
            groups=1,
            output_mask=[True, True, True])


@register_test_case(module_factory=lambda: ConvolutionBackwardModule2D())
def ConvolutionBackwardModule2D_basic(module, tu: TestUtils):
    # grad_out is 5x5 spatially: a 6x6 input convolved with a 2x2 kernel.
    with torch.backends.mkldnn.flags(enabled=False):
        module.forward(tu.rand(2, 2, 5, 5), tu.rand(2, 2, 6, 6),
                       tu.rand(2, 2, 2, 2))
|
||||
|
||||
|
||||
class ConvolutionBackwardModule2DPadded(torch.nn.Module):
    # Calls aten.convolution_backward on three dynamically shaped rank-4
    # float32 tensors for a 2-D convolution with padding of 2 on each
    # spatial dimension.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad_out, input_vec, weight):
        return torch.ops.aten.convolution_backward(
            grad_out,
            input_vec,
            weight,
            bias_sizes=None,
            stride=[1, 1],
            padding=[2, 2],
            dilation=[1, 1],
            transposed=False,
            # One entry per spatial dimension, matching stride/padding/dilation.
            output_padding=[0, 0],
            groups=1,
            output_mask=[True, True, True])


@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DPadded())
def ConvolutionBackwardModule2DPadded_basic(module, tu: TestUtils):
    # grad_out is 8x8 spatially: a 6x6 input padded by 2 and convolved with a
    # 3x3 kernel.
    with torch.backends.mkldnn.flags(enabled=False):
        module.forward(tu.rand(2, 2, 8, 8), tu.rand(2, 2, 6, 6),
                       tu.rand(2, 2, 3, 3))
|
||||
|
||||
|
||||
class ConvolutionBackwardModule3D(torch.nn.Module):
    # Calls aten.convolution_backward on three dynamically shaped rank-5
    # float32 tensors (unpadded 3-D convolution, unit stride and dilation).

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad_out, input_vec, weight):
        return torch.ops.aten.convolution_backward(
            grad_out,
            input_vec,
            weight,
            bias_sizes=None,
            stride=[1, 1, 1],
            # padding and output_padding carry one entry per spatial
            # dimension, like stride and dilation.
            padding=[0, 0, 0],
            dilation=[1, 1, 1],
            transposed=False,
            output_padding=[0, 0, 0],
            groups=1,
            output_mask=[True, True, True])


@register_test_case(module_factory=lambda: ConvolutionBackwardModule3D())
def ConvolutionBackwardModule3D_basic(module, tu: TestUtils):
    with torch.backends.mkldnn.flags(enabled=False):
        module.forward(tu.rand(3, 3, 3, 3, 3), tu.rand(3, 3, 3, 3, 3),
                       tu.rand(3, 3, 1, 1, 1))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class GeluBackwardModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -74,7 +217,9 @@ class GeluBackwardModule(torch.nn.Module):
|
|||
def GeluBackwardModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 3), tu.rand(5, 3))
|
||||
|
||||
|
||||
class LogSoftmaxBackwardModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -90,6 +235,7 @@ class LogSoftmaxBackwardModule(torch.nn.Module):
|
|||
dim=1,
|
||||
input_dtype=6)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: LogSoftmaxBackwardModule())
|
||||
def LogSoftmaxBackwardModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
|
||||
module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4))
|
||||
|
|
|
@ -12,6 +12,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
|
|||
|
||||
|
||||
class Conv2dNoPaddingModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
|
@ -34,6 +35,7 @@ def Conv2dNoPaddingModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class Conv2dBiasNoPaddingModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
|
@ -56,6 +58,7 @@ def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class Conv2dWithPaddingModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
|
@ -78,6 +81,7 @@ def Conv2dWithPaddingModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class Conv2dWithPaddingDilationStrideModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
|
@ -107,6 +111,7 @@ def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):
|
|||
|
||||
|
||||
class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
torch.manual_seed(0)
|
||||
|
@ -134,6 +139,7 @@ def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
|
|||
t = tu.rand(5, 2, 10, 20)
|
||||
module.forward(t)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class Convolution1DModule(torch.nn.Module):
|
||||
|
@ -148,14 +154,15 @@ class Convolution1DModule(torch.nn.Module):
|
|||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[1],
|
||||
padding=[0],
|
||||
dilation=[1],
|
||||
transposed=False,
|
||||
output_padding=[0],
|
||||
groups=1)
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[1],
|
||||
padding=[0],
|
||||
dilation=[1],
|
||||
transposed=False,
|
||||
output_padding=[0],
|
||||
groups=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Convolution1DModule())
|
||||
def Convolution1DModule_basic(module, tu: TestUtils):
|
||||
|
@ -198,14 +205,15 @@ class Convolution3DModule(torch.nn.Module):
|
|||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[1, 1, 1],
|
||||
padding=[0, 0, 0],
|
||||
dilation=[1, 1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0, 0],
|
||||
groups=1)
|
||||
weight,
|
||||
bias=None,
|
||||
stride=[1, 1, 1],
|
||||
padding=[0, 0, 0],
|
||||
dilation=[1, 1, 1],
|
||||
transposed=False,
|
||||
output_padding=[0, 0, 0],
|
||||
groups=1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Convolution3DModule())
|
||||
def Convolution3DModule_basic(module, tu: TestUtils):
|
||||
|
|
Loading…
Reference in New Issue