[MLIR][TORCH] Add E2E support aten.convolution_backward op

This commit adds decompositions for the `aten.convolution_backward`
and `aten.convolution_backward_overrideable` ops.
pull/1572/head snapshot-20221115.658
George Petterson 2022-11-04 03:57:29 -04:00 committed by Vivek Khandelwal
parent f40cbd6a71
commit 92f385bd9f
15 changed files with 524 additions and 35 deletions

View File

@ -622,4 +622,8 @@ LTC_XFAIL_SET = {
"Fill_TensorFloat32WithInt64_basic",
"UpSampleNearest2dBackwardVec_basic",
"UpSampleNearest2dBackwardOutputSizeNone_basic",
"ConvolutionBackwardModule1D_basic",
"ConvolutionBackwardModule2D_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule3D_basic"
}

View File

@ -3977,6 +3977,75 @@ def Torch_AtenRollOp : Torch_Op<"aten.roll", [
}];
}
// ODS definition of `torch.aten.convolution_backward`.
// The op takes 11 operands (grad_output, input, weight, optional bias_sizes,
// stride, padding, dilation, transposed, output_padding, groups, output_mask)
// and produces 3 tensor results. Its assembly format is handled by the shared
// default Torch op parser/printer, which is told those operand/result counts.
def Torch_AtenConvolutionBackwardOp : Torch_Op<"aten.convolution_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalListOfTorchIntType:$bias_sizes,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$transposed,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
AnyTorchListOfTorchBoolType:$output_mask
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConvolutionBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 11, 3);
}
void AtenConvolutionBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 11, 3);
}
}];
}
// ODS definition of `torch.aten.convolution_backward_overrideable`.
// Compared with `aten.convolution_backward` this variant has no `bias_sizes`
// operand, so it takes 10 operands; it produces 3 tensor results named
// grad_input, grad_weight and grad_bias. Parsing/printing is delegated to the
// shared default Torch op helpers with those operand/result counts.
def Torch_AtenConvolutionBackwardOverrideableOp : Torch_Op<"aten.convolution_backward_overrideable", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$transposed,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
AnyTorchListOfTorchBoolType:$output_mask
);
let results = (outs
AnyTorchTensorType:$grad_input,
AnyTorchTensorType:$grad_weight,
AnyTorchTensorType:$grad_bias
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConvolutionBackwardOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 10, 3);
}
void AtenConvolutionBackwardOverrideableOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 10, 3);
}
}];
}
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -150,6 +150,37 @@ m_TorchConstantIntList(SmallVectorImpl<int64_t> &bind_values) {
return detail::torch_list_construct_op_binder(bind_values);
}
namespace detail {
/// Matches the constant bools stored in a `torch.ListConstruct`.
struct torch_bool_list_construct_op_binder {
SmallVectorImpl<bool> &bind_values;
/// Creates a matcher instance that binds the value to bvs if match succeeds.
torch_bool_list_construct_op_binder(SmallVectorImpl<bool> &bvs)
: bind_values(bvs) {}
bool match(Operation *op) {
auto listConstruct = dyn_cast<Torch::PrimListConstructOp>(op);
if (!listConstruct)
return false;
for (Value value : listConstruct.elements()) {
bool num;
if (matchPattern(value, m_TorchConstantBool(&num)))
bind_values.push_back(num);
else
return false;
}
return true;
}
};
} // namespace detail
/// Builds a matcher for a `torch.prim.ListConstruct` of constant bools; the
/// matched values are stored into `bind_values`.
inline detail::torch_bool_list_construct_op_binder
m_TorchConstantBoolList(SmallVectorImpl<bool> &bind_values) {
  detail::torch_bool_list_construct_op_binder binder(bind_values);
  return binder;
}
namespace detail {
/// Matches the expected tensor and dim from `torch.aten.size.int`.
struct torch_tensor_size_int_op_binder {

View File

@ -398,6 +398,9 @@ public:
target.addIllegalOp<AtenDivFloatOp>();
patterns.add<ConvertAtenBinaryOp<AtenDivFloatOp, arith::DivFOp>>(
typeConverter, context);
target.addIllegalOp<AtenFloordivIntOp>();
patterns.add<ConvertAtenBinaryOp<AtenFloordivIntOp, arith::FloorDivSIOp>>(
typeConverter, context);
target.addIllegalOp<AtenCeilFloatOp>();
patterns
.add<ConvertAtenUnaryOpToFloatMathOp<AtenCeilFloatOp, math::CeilOp>>(

View File

@ -94,7 +94,6 @@ public:
namespace {
class ConvertAtenFlipOp : public OpConversionPattern<AtenFlipOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
@ -501,11 +500,12 @@ public:
return rewriter.create<arith::IndexCastOp>(loc, intType, v);
};
SmallVector<int64_t> paddingInts;
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) {
SmallVector<Value> paddingIntValues;
if (!getListConstructElements(op.padding(), paddingIntValues))
return rewriter.notifyMatchFailure(
op, "only support constant padding values");
}
op, "only support padding from a list construct");
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
paddingIntValues);
SmallVector<int64_t> strideInts;
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op,
@ -548,8 +548,6 @@ public:
"invalid: groups must divide weight batch size evenly.");
SmallVector<Value> dilationIntValues =
getAsConstantIntValues(rewriter, loc, dilationInts);
SmallVector<Value> paddingIntValues =
getAsConstantIntValues(rewriter, loc, paddingInts);
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);
@ -647,11 +645,8 @@ public:
} else {
// Pad input
SmallVector<int64_t, 4> paddingIncludingNC = {0, 0};
paddingIncludingNC.insert(paddingIncludingNC.end(), paddingInts.begin(),
paddingInts.end());
paddedInput = torch_to_linalg::getZeroPaddedTensor(op, rewriter, input,
paddingIncludingNC);
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2);
// Calculate output dims
for (size_t i = 0; i < numSpacialDims; i++)
@ -752,7 +747,6 @@ public:
}
// Grouped case, use the grouped conv linalg op
auto expandGroups = [&](Value tensor, size_t dim) {
auto inType = tensor.getType().cast<RankedTensorType>();
auto inShape = inType.getShape();

View File

@ -67,6 +67,40 @@ Value torch_to_linalg::getZeroPaddedTensor(
return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
}
// Helper function that adds dynamic padding to a tensor, ignoring unpaddedDims
// dimensions at the beginning. The high and low padding are the same, and the
// padding value is zero.
//
// `padding` holds one integer Value per padded dimension, so
// `unpaddedDims + padding.size()` must equal the rank of `input`. The values
// are cast to index before being used as pad amounts. The returned tensor has
// the same rank and element type as `input`, with every dimension dynamic.
Value torch_to_linalg::getDynamicZeroPaddedTensor(
    Operation *op, OpBuilder &b, Value &input, SmallVectorImpl<Value> &padding,
    int unpaddedDims) {
  assert(input.getType().isa<RankedTensorType>() &&
         "input must be RankedTensorType");
  auto inputTensorType = input.getType().cast<RankedTensorType>();
  unsigned int inRank = inputTensorType.getRank();
  // Validate the arguments before any IR is emitted.
  assert(unpaddedDims >= 0 && "unpaddedDims must be non-negative");
  assert(unpaddedDims + padding.size() == inRank &&
         "sum of unpaddedDims and padding.size() must equal to inputRank");

  Location loc = op->getLoc();

  // Leading (unpadded) dimensions get a pad amount of zero.
  Value c0 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
  SmallVector<Value> paddingIncludingUnchanged(unpaddedDims, c0);
  paddingIncludingUnchanged.append(padding);
  for (Value &pad : paddingIncludingUnchanged)
    pad = castIntToIndex(b, loc, pad);

  // The pad amounts are only known at runtime, so the result shape is fully
  // dynamic.
  Type elementType = inputTensorType.getElementType();
  SmallVector<int64_t> resultShape(inRank, kUnknownSize);
  Type resultType = RankedTensorType::get(resultShape, elementType);

  // Use the zero attribute of the element type so that non-float element
  // types get a correctly typed pad value as well.
  Value zero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
  SmallVector<OpFoldResult> paddingValues =
      getAsOpFoldResult(paddingIncludingUnchanged);
  return b.create<tensor::PadOp>(loc, resultType, input, /*low=*/paddingValues,
                                 /*high=*/paddingValues, zero);
}
Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc,
Value in, Value paddingInt,
Value dilationInt,

View File

@ -30,6 +30,13 @@ Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
Value getZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input,
SmallVectorImpl<int64_t> &paddingInts);
// Helper function that adds dynamic padding to a tensor, ignoring unpaddedDims
// dimensions at the beginning. The high and low padding are the same, and the
// padding value is zero.
//
// `padding` supplies one runtime integer Value per padded dimension, i.e. the
// rank of `input` must equal `unpaddedDims + padding.size()`. `op` is only
// used to obtain the location for the created ops. The result is a ranked
// tensor of the same rank and element type as `input` whose dimensions are all
// dynamic.
Value getDynamicZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input,
SmallVectorImpl<Value> &padding,
int unpaddedDims = 0);
// Helper function to calculate the output tensor dims for convolution-like ops.
// Along each dim:
// dim_out =

View File

@ -1067,7 +1067,7 @@ public:
};
} // namespace
// Decompose aten.convolution_overrideable to aten.convolution
// Decompose aten.convolution_overrideable to aten.convolution op.
namespace {
class DecomposeAtenConvolutionOverrideableOp
: public OpRewritePattern<AtenConvolutionOverrideableOp> {
@ -1088,7 +1088,7 @@ public:
// Decompose aten._convolution-like to aten.convolution
namespace {
template<typename ConvolutionLikeOp>
template <typename ConvolutionLikeOp>
class DecomposeAten_ConvolutionLikeOp
: public OpRewritePattern<ConvolutionLikeOp> {
public:
@ -1142,7 +1142,174 @@ public:
op, op->getResultTypes(), op.input(), op.weight(), op.bias(),
op.stride(), op.padding(), op.dilation(), /*transposed=*/cstTrue,
op.output_padding(), op.groups());
return success();
}
};
} // namespace
// Rewrites aten.convolution_backward_overrideable as
// aten.convolution_backward. The overrideable variant carries no `bias_sizes`
// operand, so a `none` constant is supplied in that position; every other
// operand and all result types are forwarded unchanged.
namespace {
class DecomposeAtenConvolutionBackwardOverrideableOp
    : public OpRewritePattern<AtenConvolutionBackwardOverrideableOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(AtenConvolutionBackwardOverrideableOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Value noBiasSizes = rewriter.create<ConstantNoneOp>(loc);

    rewriter.replaceOpWithNewOp<AtenConvolutionBackwardOp>(
        op, op.getResultTypes(),
        /*grad_output=*/op.grad_output(),
        /*input=*/op.input(),
        /*weight=*/op.weight(),
        /*bias_sizes=*/noBiasSizes,
        /*stride=*/op.stride(),
        /*padding=*/op.padding(),
        /*dilation=*/op.dilation(),
        /*transposed=*/op.transposed(),
        /*output_padding=*/op.output_padding(),
        /*groups=*/op.groups(),
        /*output_mask=*/op.output_mask());
    return success();
  }
};
} // namespace
namespace {
// Decompose aten.convolution_backward into:
//   - grad_input:  an AtenConvolutionOp of grad_output with the spatially
//                  flipped, (0, 1)-transposed weight,
//   - grad_weight: an AtenConvolutionOp of the (0, 1)-transposed input with
//                  the (0, 1)-transposed grad_output, transposed back,
//   - grad_bias:   an AtenSumDimIntListOp of grad_output over every dimension
//                  except dim 1.
//
// Supported subset: rank-4 grad_output, non-transposed convolution, constant
// all-true output_mask, and list-constructed stride/padding/dilation. Strides
// and dilations other than 1 are rejected at runtime through
// Torch::RuntimeAssertOp.
class DecomposeAtenConvolutionBackwardOp
    : public OpRewritePattern<AtenConvolutionBackwardOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenConvolutionBackwardOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = op.getContext();

    Value gradOutput = op.grad_output();
    Value input = op.input();
    Value weight = op.weight();

    // ---------------------------------------------------------------------
    // Match phase. Nothing may be created through the rewriter here: a
    // pattern must leave the IR untouched when it reports a match failure.
    // ---------------------------------------------------------------------
    auto gradRank = getTensorRank(gradOutput);
    if (gradRank != 4)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only 2D convolutions supported.");

    SmallVector<Value> padding;
    if (!getListConstructElements(op.padding(), padding))
      return rewriter.notifyMatchFailure(op, "padding must be a list.");

    SmallVector<Value> strides;
    if (!getListConstructElements(op.stride(), strides))
      return rewriter.notifyMatchFailure(op, "stride must be a list.");
    // `strides` is indexed once per spatial dimension further down; reject
    // lists that are too short instead of reading out of bounds.
    if (strides.size() < static_cast<size_t>(gradRank - 2))
      return rewriter.notifyMatchFailure(
          op, "stride must have an entry for every spatial dimension.");

    SmallVector<Value> dilations;
    if (!getListConstructElements(op.dilation(), dilations))
      return rewriter.notifyMatchFailure(op, "dilation must be a list.");

    SmallVector<bool> outMask;
    if (!matchPattern(op.output_mask(), m_TorchConstantBoolList(outMask)))
      return rewriter.notifyMatchFailure(
          op, "only constant bool output_mask is supported.");
    // Support for `False` values for output mask unimplemented.
    if (!llvm::all_of(outMask, [](bool mask) { return mask; }))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: only true values for output_mask supported.");

    bool transposed = false;
    if (!matchPattern(op.transposed(), m_TorchConstantBool(&transposed)))
      return rewriter.notifyMatchFailure(
          op, "only constant transposed is supported.");
    if (transposed)
      return rewriter.notifyMatchFailure(
          op, "unimplemented: transposed convolutions are not supported.");

    // ---------------------------------------------------------------------
    // Rewrite phase. From here on the pattern always succeeds.
    // ---------------------------------------------------------------------
    auto createInt = [&](int64_t value) -> Value {
      return rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(value));
    };
    Value cstZero = createInt(0);
    Value cstOne = createInt(1);
    Value cstTwo = createInt(2);
    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
        loc, rewriter.getBoolAttr(false));
    Type intListType = Torch::ListType::get(Torch::IntType::get(context));

    // Strides and dilations may be non-constant, so the "must be 1"
    // restriction is enforced with runtime assertions.
    for (Value stride : strides) {
      Value cmp = rewriter.create<Torch::AtenEqIntOp>(loc, stride, cstOne);
      rewriter.create<Torch::RuntimeAssertOp>(
          loc, cmp, "unimplemented: only strides of 1 supported.");
    }
    for (Value dilation : dilations) {
      Value cmp = rewriter.create<Torch::AtenEqIntOp>(loc, dilation, cstOne);
      rewriter.create<Torch::RuntimeAssertOp>(
          loc, cmp, "unimplemented: only dilations of 1 supported.");
    }

    // Rotate weight: flip it along every spatial dimension.
    SmallVector<Value> axes;
    for (auto i = 2; i < gradRank; i++)
      axes.push_back(createInt(i));
    Value axesList =
        rewriter.create<Torch::PrimListConstructOp>(loc, intListType, axes);
    weight = rewriter.create<Torch::AtenFlipOp>(loc, weight.getType(), weight,
                                                axesList);

    // Calculate padding for first convolution.
    SmallVector<Value> gradInputPaddingValues;
    for (auto i = 2; i < gradRank; i++) {
      Value dim = createInt(i);
      Value outDim = rewriter.create<Torch::AtenSizeIntOp>(loc, input, dim);

      // Calculate 1 + (weightDim // 2) * 2, which fixes issues with
      // even-sized weight.
      Value weightDim = rewriter.create<Torch::AtenSizeIntOp>(loc, weight, dim);
      weightDim =
          rewriter.create<Torch::AtenFloordivIntOp>(loc, weightDim, cstTwo);
      weightDim = rewriter.create<Torch::AtenMulIntOp>(loc, weightDim, cstTwo);
      weightDim = rewriter.create<Torch::AtenAddIntOp>(loc, weightDim, cstOne);

      Value gradOutDim =
          rewriter.create<Torch::AtenSizeIntOp>(loc, gradOutput, dim);

      // Calculate (((outDim - 1) * stride) + weightDim - gradOutDim) // 2,
      // the padding value for this dimension. Derived from the formula at
      // https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
      Value padVal = rewriter.create<Torch::AtenSubIntOp>(loc, outDim, cstOne);
      padVal =
          rewriter.create<Torch::AtenMulIntOp>(loc, padVal, strides[i - 2]);
      padVal = rewriter.create<Torch::AtenAddIntOp>(loc, padVal, weightDim);
      padVal = rewriter.create<Torch::AtenSubIntOp>(loc, padVal, gradOutDim);
      padVal = rewriter.create<Torch::AtenFloordivIntOp>(loc, padVal, cstTwo);
      gradInputPaddingValues.push_back(padVal);
    }
    Value gradInputPadding = rewriter.create<Torch::PrimListConstructOp>(
        loc, intListType, gradInputPaddingValues);

    // NOTE(review): the transposes below reuse the type of their operand as
    // the result type; that is only accurate when dims 0 and 1 have the same
    // (or dynamic) size -- confirm this is intended for static shapes.
    Value weightTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
        loc, weight.getType(), weight, cstZero, cstOne);

    // Convolve grad_output with weight.
    Value gradInput = rewriter.create<Torch::AtenConvolutionOp>(
        loc, op.getResultTypes()[0], gradOutput, weightTransposed, cstNone,
        op.stride(), gradInputPadding, op.dilation(), op.transposed(),
        op.output_padding(), op.groups());

    Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
        loc, gradOutput.getType(), gradOutput, cstZero, cstOne);
    Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
        loc, input.getType(), input, cstZero, cstOne);

    // Convolve input with grad_output.
    Value gradWeight = rewriter.create<Torch::AtenConvolutionOp>(
        loc, op.getResultTypes()[1], inputTransposed, gradOutputTransposed,
        cstNone, op.stride(), op.padding(), op.dilation(), op.transposed(),
        op.output_padding(), op.groups());
    gradWeight = rewriter.create<Torch::AtenTransposeIntOp>(
        loc, gradWeight.getType(), gradWeight, cstZero, cstOne);

    // Sum grad_output over dim 0 and every spatial dim, i.e. over all
    // dimensions except dim 1, to obtain the bias gradient.
    SmallVector<Value> dimIntList{cstZero};
    for (auto i = 2; i < gradRank; i++)
      dimIntList.push_back(createInt(i));
    Value gradIntList = rewriter.create<Torch::PrimListConstructOp>(
        loc, intListType, dimIntList);
    Value gradBias = rewriter.create<Torch::AtenSumDimIntListOp>(
        loc, op.getResultTypes()[2], gradOutput, gradIntList, cstFalse,
        cstNone);

    rewriter.replaceOp(op, {gradInput, gradWeight, gradBias});
    return success();
  }
};
@ -2940,6 +3107,8 @@ public:
target.addIllegalOp<AtenWhereScalarOtherOp>();
patterns.add<DecomposeAtenWhereScalarSelfOp>(context);
target.addIllegalOp<AtenWhereScalarSelfOp>();
patterns.add<DecomposeAtenConvolutionBackwardOverrideableOp>(context);
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
patterns.add<DecomposeAtenSizeOp>(context);
target.addIllegalOp<AtenSizeOp>();
patterns.add<DecomposeAtenReshapeOp>(context);
@ -2989,6 +3158,8 @@ public:
patterns.add<DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>,
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionDeprecatedOp>>(
context);
target.addIllegalOp<AtenConvolutionBackwardOp>();
patterns.add<DecomposeAtenConvolutionBackwardOp>(context);
target.addIllegalOp<AtenConv2dOp>();
patterns.add<DecomposeAtenConv2dOp>(context);
target.addIllegalOp<AtenConvTranspose2dInputOp>();

View File

@ -885,7 +885,9 @@ void TypeAnalysis::visitOperation(Operation *op,
}
// 3 results take dtype from first operand.
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp>(op)) {
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp,
AtenConvolutionBackwardOp, AtenConvolutionBackwardOverrideableOp>(
op)) {
auto self = operands[0]->getValue();
auto result0Knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());

View File

@ -6544,6 +6544,16 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.flip\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.convolution_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.bool, %arg8: !torch.list<int>, %arg9: !torch.int, %arg10: !torch.list<bool>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.convolution_backward_overrideable\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.list<bool>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
" %1 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"

View File

@ -981,6 +981,12 @@ def aten_convolutiondeprecated(input: List[int], weight: List[int], bias:
# Shape function for aten.flip: the result has the same shape as `self`;
# `dims` does not influence the shape.
def atenflip(self: List[int], dims: List[int]) -> List[int]:
return self
# Shape function for aten.convolution_backward. Returns a tuple of three
# shapes by delegating to the upstream `conv_backwards` helper, which only
# receives the grad_output, input and weight shapes plus `bias_sizes`; the
# remaining convolution parameters are not consulted here.
def atenconvolution_backward(grad_output: List[int], input: List[int], weight: List[int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
return upstream_shape_functions.conv_backwards(grad_output, input, weight, bias_sizes)
# Shape function for aten.convolution_backward_overrideable. This op has no
# `bias_sizes` argument, so `None` is passed to the upstream `conv_backwards`
# helper in its place; the other convolution parameters are not consulted.
def atenconvolution_backward_overrideable(grad_output: List[int], input: List[int], weight: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
return upstream_shape_functions.conv_backwards(grad_output, input, weight, None)
def atenbatch_norm(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
# Torch's symbolic shape analysis is a bit looser about optional
# arguments than we are, so their batch_norm helper function works

View File

@ -350,6 +350,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
emit("aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"

View File

@ -15,6 +15,8 @@ COMMON_TORCH_MLIR_LOWERING_XFAILS = {
"Conv_Transpose1dModule_basic",
"MaxPool2dWith3dInputModule_basic",
"MaxPool2dWithIndicesWith3dInputModule_basic",
"ConvolutionBackwardModule1D_basic",
"ConvolutionBackwardModule3D_basic",
}
def register_all_tests():

View File

@ -13,6 +13,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
class SoftmaxBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -31,11 +32,12 @@ class SoftmaxBackwardModule(torch.nn.Module):
@register_test_case(module_factory=lambda: SoftmaxBackwardModule())
def SoftmaxBackwardModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4))
# ==============================================================================
class TanhBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -51,12 +53,153 @@ class TanhBackwardModule(torch.nn.Module):
@register_test_case(module_factory=lambda: TanhBackwardModule())
def TanhBackward_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 3), torch.randn(3, 3))
module.forward(tu.rand(3, 3), tu.rand(3, 3))
# ==============================================================================
class ConvolutionBackwardModule1D(torch.nn.Module):
    # Calls aten.convolution_backward on rank-3 float32 tensors with dynamic
    # shapes, using single-element stride/padding/dilation lists.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad_out, input_vec, weight):
        # All three gradients are requested through output_mask.
        grads = torch.ops.aten.convolution_backward(
            grad_out,
            input_vec,
            weight,
            bias_sizes=None,
            stride=[1],
            padding=[0],
            dilation=[1],
            transposed=False,
            output_padding=[0],
            groups=1,
            output_mask=[True, True, True],
        )
        return grads
@register_test_case(module_factory=lambda: ConvolutionBackwardModule1D())
def ConvolutionBackwardModule1D_basic(module, tu: TestUtils):
    # The call runs with the mkldnn backend flag disabled.
    with torch.backends.mkldnn.flags(enabled=False):
        grad_out = tu.rand(3, 3, 3)
        input_vec = tu.rand(3, 3, 3)
        weight = tu.rand(3, 3, 1)
        module.forward(grad_out, input_vec, weight)
class ConvolutionBackwardModule2D(torch.nn.Module):
    # Calls aten.convolution_backward on rank-4 float32 tensors with dynamic
    # shapes, unit stride/dilation and zero padding.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad_out, input_vec, weight):
        # All three gradients are requested through output_mask.
        grads = torch.ops.aten.convolution_backward(
            grad_out,
            input_vec,
            weight,
            bias_sizes=None,
            stride=[1, 1],
            padding=[0, 0],
            dilation=[1, 1],
            transposed=False,
            output_padding=[0],
            groups=1,
            output_mask=[True, True, True],
        )
        return grads
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2D())
def ConvolutionBackwardModule2D_basic(module, tu: TestUtils):
    # The call runs with the mkldnn backend flag disabled.
    with torch.backends.mkldnn.flags(enabled=False):
        grad_out = tu.rand(2, 2, 5, 5)
        input_vec = tu.rand(2, 2, 6, 6)
        weight = tu.rand(2, 2, 2, 2)
        module.forward(grad_out, input_vec, weight)
class ConvolutionBackwardModule2DPadded(torch.nn.Module):
    # Calls aten.convolution_backward on rank-4 float32 tensors with dynamic
    # shapes, unit stride/dilation and a padding of 2 in both spatial dims.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad_out, input_vec, weight):
        # All three gradients are requested through output_mask.
        grads = torch.ops.aten.convolution_backward(
            grad_out,
            input_vec,
            weight,
            bias_sizes=None,
            stride=[1, 1],
            padding=[2, 2],
            dilation=[1, 1],
            transposed=False,
            output_padding=[0],
            groups=1,
            output_mask=[True, True, True],
        )
        return grads
@register_test_case(module_factory=lambda: ConvolutionBackwardModule2DPadded())
def ConvolutionBackwardModule2DPadded_basic(module, tu: TestUtils):
    # The call runs with the mkldnn backend flag disabled.
    with torch.backends.mkldnn.flags(enabled=False):
        grad_out = tu.rand(2, 2, 8, 8)
        input_vec = tu.rand(2, 2, 6, 6)
        weight = tu.rand(2, 2, 3, 3)
        module.forward(grad_out, input_vec, weight)
class ConvolutionBackwardModule3D(torch.nn.Module):
    # Calls aten.convolution_backward on rank-5 float32 tensors with dynamic
    # shapes and unit stride/dilation.

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1, -1], torch.float32, True),
        ([-1, -1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad_out, input_vec, weight):
        # All three gradients are requested through output_mask.
        grads = torch.ops.aten.convolution_backward(
            grad_out,
            input_vec,
            weight,
            bias_sizes=None,
            stride=[1, 1, 1],
            padding=[0],
            dilation=[1, 1, 1],
            transposed=False,
            output_padding=[0],
            groups=1,
            output_mask=[True, True, True],
        )
        return grads
@register_test_case(module_factory=lambda: ConvolutionBackwardModule3D())
def ConvolutionBackwardModule3D_basic(module, tu: TestUtils):
    # The call runs with the mkldnn backend flag disabled.
    with torch.backends.mkldnn.flags(enabled=False):
        grad_out = tu.rand(3, 3, 3, 3, 3)
        input_vec = tu.rand(3, 3, 3, 3, 3)
        weight = tu.rand(3, 3, 1, 1, 1)
        module.forward(grad_out, input_vec, weight)
# ==============================================================================
class GeluBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -74,7 +217,9 @@ class GeluBackwardModule(torch.nn.Module):
def GeluBackwardModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3), tu.rand(5, 3))
class LogSoftmaxBackwardModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -90,6 +235,7 @@ class LogSoftmaxBackwardModule(torch.nn.Module):
dim=1,
input_dtype=6)
@register_test_case(module_factory=lambda: LogSoftmaxBackwardModule())
def LogSoftmaxBackwardModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
module.forward(tu.rand(3, 2, 4), tu.rand(3, 2, 4))

View File

@ -12,6 +12,7 @@ from torch_mlir_e2e_test.annotations import annotate_args, export
class Conv2dNoPaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
@ -34,6 +35,7 @@ def Conv2dNoPaddingModule_basic(module, tu: TestUtils):
class Conv2dBiasNoPaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
@ -56,6 +58,7 @@ def Conv2dBiasNoPaddingModule_basic(module, tu: TestUtils):
class Conv2dWithPaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
@ -78,6 +81,7 @@ def Conv2dWithPaddingModule_basic(module, tu: TestUtils):
class Conv2dWithPaddingDilationStrideModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
@ -107,6 +111,7 @@ def Conv2dWithPaddingDilationStrideModule_basic(module, tu: TestUtils):
class Conv2dWithPaddingDilationStrideStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
@ -134,6 +139,7 @@ def Conv2dWithPaddingDilationStrideStaticModule_basic(module, tu: TestUtils):
t = tu.rand(5, 2, 10, 20)
module.forward(t)
# ==============================================================================
class Convolution1DModule(torch.nn.Module):
@ -148,14 +154,15 @@ class Convolution1DModule(torch.nn.Module):
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
weight,
bias=None,
stride=[1],
padding=[0],
dilation=[1],
transposed=False,
output_padding=[0],
groups=1)
weight,
bias=None,
stride=[1],
padding=[0],
dilation=[1],
transposed=False,
output_padding=[0],
groups=1)
@register_test_case(module_factory=lambda: Convolution1DModule())
def Convolution1DModule_basic(module, tu: TestUtils):
@ -198,14 +205,15 @@ class Convolution3DModule(torch.nn.Module):
])
def forward(self, inputVec, weight):
return torch.ops.aten.convolution(inputVec,
weight,
bias=None,
stride=[1, 1, 1],
padding=[0, 0, 0],
dilation=[1, 1, 1],
transposed=False,
output_padding=[0, 0, 0],
groups=1)
weight,
bias=None,
stride=[1, 1, 1],
padding=[0, 0, 0],
dilation=[1, 1, 1],
transposed=False,
output_padding=[0, 0, 0],
groups=1)
@register_test_case(module_factory=lambda: Convolution3DModule())
def Convolution3DModule_basic(module, tu: TestUtils):