Remove convolution_overrideable, convolution_backward_overrideable (#1984)

The ops `aten.convolution_overrideable` and
`aten.convolution_backward_overrideable` are currently not e2e tested
in Torch-MLIR. Moreover, there is no way to add e2e tests for them
because the ops cannot be called using the CPU backend (this also
prevents adding tested dtype functions for these ops). Since these two
ops are not expected to ever appear in PyTorch traces obtained through
standard means (https://github.com/pytorch/pytorch/issues/97481),
Torch-MLIR should not have to worry about them.
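
For reference, the "cannot be called using the CPU backend" point can be
reproduced with a short snippet like the one below (a hypothetical repro, not
part of this patch; the tensor shapes are illustrative and the exact error
text/exception type may vary across PyTorch versions):

import torch

# convolution_overrideable is a placeholder op intended to be overridden by
# out-of-tree backends; its stock implementation only raises an error, so
# calling it on ordinary CPU tensors is expected to fail.
inp = torch.ones(1, 1, 3, 3)
weight = torch.ones(1, 1, 2, 2)
try:
    torch.ops.aten.convolution_overrideable(
        inp, weight, None,        # input, weight, bias
        [1, 1], [0, 0], [1, 1],   # stride, padding, dilation
        False, [0, 0], 1)         # transposed, output_padding, groups
except RuntimeError as e:         # NotImplementedError is a RuntimeError subclass
    print(e)                      # expect a "not implemented" style error
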
Ramiro Leal-Cavazos 2023-03-29 15:05:56 -07:00 committed by GitHub
parent 0103c55e55
commit 42d780dde0
7 changed files with 3 additions and 125 deletions

@@ -4343,37 +4343,6 @@ def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
}];
}
def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideable", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchOptionalTensorType:$bias,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$transposed,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConvolutionOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 1);
}
void AtenConvolutionOverrideableOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 1);
}
}];
}
def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -4503,40 +4472,6 @@ def Torch_AtenConvolutionBackwardOp : Torch_Op<"aten.convolution_backward", [
}];
}
def Torch_AtenConvolutionBackwardOverrideableOp : Torch_Op<"aten.convolution_backward_overrideable", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$input,
AnyTorchTensorType:$weight,
AnyTorchListOfTorchIntType:$stride,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$dilation,
Torch_BoolType:$transposed,
AnyTorchListOfTorchIntType:$output_padding,
Torch_IntType:$groups,
AnyTorchListOfTorchBoolType:$output_mask
);
let results = (outs
AnyTorchTensorType:$grad_input,
AnyTorchTensorType:$grad_weight,
AnyTorchTensorType:$grad_bias
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenConvolutionBackwardOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 10, 3);
}
void AtenConvolutionBackwardOverrideableOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 10, 3);
}
}];
}
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
AllowsTypeRefinement,
HasValueSemantics,

@@ -7061,12 +7061,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.convolution_backward_overrideable\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.list<bool>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %0 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
" %1 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

@@ -1452,24 +1452,6 @@ public:
}
};
} // namespace
// Decompose aten.convolution_overrideable to aten.convolution op.
namespace {
class DecomposeAtenConvolutionOverrideableOp
: public OpRewritePattern<AtenConvolutionOverrideableOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenConvolutionOverrideableOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
op.getOutputPadding(), op.getGroups());
return success();
}
};
} // namespace
// Decompose aten._convolution-like to aten.convolution
@@ -1533,27 +1515,6 @@
};
} // namespace
// Decompose aten.convolution_backward_overrideable to aten.convolution_backward
// op.
namespace {
class DecomposeAtenConvolutionBackwardOverrideableOp
: public OpRewritePattern<AtenConvolutionBackwardOverrideableOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenConvolutionBackwardOverrideableOp op,
PatternRewriter &rewriter) const override {
Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
rewriter.replaceOpWithNewOp<AtenConvolutionBackwardOp>(
op, op.getResultTypes(), op.getGradOutput(), op.getInput(), op.getWeight(),
none, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
op.getOutputPadding(), op.getGroups(), op.getOutputMask());
return success();
}
};
} // namespace
namespace {
class DecomposeAtenConvolutionBackwardOp
: public OpRewritePattern<AtenConvolutionBackwardOp> {
@@ -3926,8 +3887,6 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
@@ -3949,8 +3908,6 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionOverrideableOp>(
patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
addPatternIfTargetOpIsIllegal<

@@ -379,7 +379,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>();
target.addIllegalOp<AtenMaskedFillScalarOp>();
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
target.addIllegalOp<AtenSizeOp>();
target.addIllegalOp<AtenReshapeOp>();
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
@@ -405,7 +404,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenLayerNormOp>();
target.addIllegalOp<AtenNativeLayerNormOp>();
target.addIllegalOp<AtenNativeBatchNormOp>();
target.addIllegalOp<AtenConvolutionOverrideableOp>();
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
target.addIllegalOp<AtenConvolutionBackwardOp>();
target.addIllegalOp<AtenConv2dOp>();

@@ -714,8 +714,8 @@ void TypeAnalysis::visitOperation(Operation *op,
// Promote the two dtypes assuming non-zero rank.
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
Aten_ConvolutionOp, AtenMvOp, AtenConvolutionOverrideableOp,
AtenConvTranspose2dInputOp, AtenMseLossOp>(op)) {
Aten_ConvolutionOp, AtenMvOp, AtenConvTranspose2dInputOp,
AtenMseLossOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
@@ -845,8 +845,7 @@ void TypeAnalysis::visitOperation(Operation *op,
// 3 results take dtype from first operand.
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp,
AtenConvolutionBackwardOp, AtenConvolutionBackwardOverrideableOp>(
op)) {
AtenConvolutionBackwardOp>(op)) {
auto self = operands[0]->getValue();
auto result0Knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());

@@ -840,9 +840,6 @@ def aten〇flip〡shape(self: List[int], dims: List[int]) -> List[int]:
def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int], weight: List[int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
return upstream_shape_functions.conv_backwards(grad_output, input, weight, bias_sizes)
def aten〇convolution_backward_overrideable〡shape(grad_output: List[int], input: List[int], weight: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
return upstream_shape_functions.conv_backwards(grad_output, input, weight, None)
def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)

@@ -359,12 +359,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
emit("aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
emit(
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"