mirror of https://github.com/llvm/torch-mlir
Remove convolution_overrideable, convolution_backward_overrideable (#1984)
The ops `aten.convolution_overrideable` and `aten.convolution_backward_overrideable` are currently not e2e tested in Torch-MLIR. Moreover, there is no way to add e2e tests for them because the ops cannot be called using the CPU backend (this also prevents adding tested dtype functions for these ops). Since these two ops are not expected to ever appear in PyTorch traces obtained through standard means (https://github.com/pytorch/pytorch/issues/97481), Torch-MLIR should not have to worry about them.pull/1988/head snapshot-20230330.793
parent
0103c55e55
commit
42d780dde0
|
@ -4343,37 +4343,6 @@ def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideable", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$input,
|
||||
AnyTorchTensorType:$weight,
|
||||
AnyTorchOptionalTensorType:$bias,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
AnyTorchListOfTorchIntType:$dilation,
|
||||
Torch_BoolType:$transposed,
|
||||
AnyTorchListOfTorchIntType:$output_padding,
|
||||
Torch_IntType:$groups
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenConvolutionOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 9, 1);
|
||||
}
|
||||
void AtenConvolutionOverrideableOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 9, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten_ConvolutionOp : Torch_Op<"aten._convolution", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -4503,40 +4472,6 @@ def Torch_AtenConvolutionBackwardOp : Torch_Op<"aten.convolution_backward", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenConvolutionBackwardOverrideableOp : Torch_Op<"aten.convolution_backward_overrideable", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$grad_output,
|
||||
AnyTorchTensorType:$input,
|
||||
AnyTorchTensorType:$weight,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchListOfTorchIntType:$padding,
|
||||
AnyTorchListOfTorchIntType:$dilation,
|
||||
Torch_BoolType:$transposed,
|
||||
AnyTorchListOfTorchIntType:$output_padding,
|
||||
Torch_IntType:$groups,
|
||||
AnyTorchListOfTorchBoolType:$output_mask
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$grad_input,
|
||||
AnyTorchTensorType:$grad_weight,
|
||||
AnyTorchTensorType:$grad_bias
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenConvolutionBackwardOverrideableOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 10, 3);
|
||||
}
|
||||
void AtenConvolutionBackwardOverrideableOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 10, 3);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFlipOp : Torch_Op<"aten.flip", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -7061,12 +7061,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.convolution_backward_overrideable\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>, %arg8: !torch.int, %arg9: !torch.list<bool>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.derefine %none : !torch.none to !torch.optional<list<int>>\n"
|
||||
" %1 = call @__torch__.torch.jit._shape_functions.conv_backwards(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" return %1 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.batch_norm\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<int>>, %arg3: !torch.optional<list<int>>, %arg4: !torch.optional<list<int>>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list<int>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.optional<list<int>>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
|
|
@ -1452,24 +1452,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
// Decompose aten.convolution_overrideable to aten.convolution op.
|
||||
namespace {
|
||||
class DecomposeAtenConvolutionOverrideableOp
|
||||
: public OpRewritePattern<AtenConvolutionOverrideableOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenConvolutionOverrideableOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenConvolutionOp>(
|
||||
op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(),
|
||||
op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
|
||||
op.getOutputPadding(), op.getGroups());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten._convolution-like to aten.convolution
|
||||
|
@ -1533,27 +1515,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.convolution_backward_overrideable to aten.convolution_backward
|
||||
// op.
|
||||
namespace {
|
||||
class DecomposeAtenConvolutionBackwardOverrideableOp
|
||||
: public OpRewritePattern<AtenConvolutionBackwardOverrideableOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenConvolutionBackwardOverrideableOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Value none = rewriter.create<ConstantNoneOp>(op->getLoc());
|
||||
rewriter.replaceOpWithNewOp<AtenConvolutionBackwardOp>(
|
||||
op, op.getResultTypes(), op.getGradOutput(), op.getInput(), op.getWeight(),
|
||||
none, op.getStride(), op.getPadding(), op.getDilation(), op.getTransposed(),
|
||||
op.getOutputPadding(), op.getGroups(), op.getOutputMask());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenConvolutionBackwardOp
|
||||
: public OpRewritePattern<AtenConvolutionBackwardOp> {
|
||||
|
@ -3926,8 +3887,6 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenConvolutionBackwardOverrideableOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_SoftmaxBackwardDataOp>(
|
||||
|
@ -3949,8 +3908,6 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenLayerNormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeLayerNormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNativeBatchNormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenConvolutionOverrideableOp>(
|
||||
patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAten_ConvolutionLikeOp<Aten_ConvolutionOp>>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
|
|
|
@ -379,7 +379,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenWhereScalarOtherOp>();
|
||||
target.addIllegalOp<AtenWhereScalarSelfOp>();
|
||||
target.addIllegalOp<AtenMaskedFillScalarOp>();
|
||||
target.addIllegalOp<AtenConvolutionBackwardOverrideableOp>();
|
||||
target.addIllegalOp<AtenSizeOp>();
|
||||
target.addIllegalOp<AtenReshapeOp>();
|
||||
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
||||
|
@ -405,7 +404,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenLayerNormOp>();
|
||||
target.addIllegalOp<AtenNativeLayerNormOp>();
|
||||
target.addIllegalOp<AtenNativeBatchNormOp>();
|
||||
target.addIllegalOp<AtenConvolutionOverrideableOp>();
|
||||
target.addIllegalOp<Aten_ConvolutionOp, Aten_ConvolutionDeprecatedOp>();
|
||||
target.addIllegalOp<AtenConvolutionBackwardOp>();
|
||||
target.addIllegalOp<AtenConv2dOp>();
|
||||
|
|
|
@ -714,8 +714,8 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
|
||||
// Promote the two dtypes assuming non-zero rank.
|
||||
if (isa<AtenMmOp, AtenBmmOp, AtenMatmulOp, AtenConv2dOp, AtenConvolutionOp,
|
||||
Aten_ConvolutionOp, AtenMvOp, AtenConvolutionOverrideableOp,
|
||||
AtenConvTranspose2dInputOp, AtenMseLossOp>(op)) {
|
||||
Aten_ConvolutionOp, AtenMvOp, AtenConvTranspose2dInputOp,
|
||||
AtenMseLossOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank(
|
||||
|
@ -845,8 +845,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
|
||||
// 3 results take dtype from first operand.
|
||||
if (isa<AtenNativeLayerNormOp, AtenNativeBatchNormOp,
|
||||
AtenConvolutionBackwardOp, AtenConvolutionBackwardOverrideableOp>(
|
||||
op)) {
|
||||
AtenConvolutionBackwardOp>(op)) {
|
||||
auto self = operands[0]->getValue();
|
||||
auto result0Knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
|
|
|
@ -840,9 +840,6 @@ def aten〇flip〡shape(self: List[int], dims: List[int]) -> List[int]:
|
|||
def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int], weight: List[int], bias_sizes: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
|
||||
return upstream_shape_functions.conv_backwards(grad_output, input, weight, bias_sizes)
|
||||
|
||||
def aten〇convolution_backward_overrideable〡shape(grad_output: List[int], input: List[int], weight: List[int], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, output_mask: List[bool]) -> Tuple[List[int], List[int], List[int]]:
|
||||
return upstream_shape_functions.conv_backwards(grad_output, input, weight, None)
|
||||
|
||||
def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]:
|
||||
return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)
|
||||
|
||||
|
|
|
@ -359,12 +359,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
|
||||
emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)")
|
||||
emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
||||
emit("aten::convolution_overrideable : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)")
|
||||
emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)")
|
||||
emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)")
|
||||
emit("aten::roll : (Tensor, int[], int[]) -> (Tensor)"),
|
||||
emit("aten::convolution_backward : (Tensor, Tensor, Tensor, int[]?, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::convolution_backward_overrideable : (Tensor, Tensor, Tensor, int[], int[], int[], bool, int[], int, bool[]) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::flip : (Tensor, int[]) -> (Tensor)")
|
||||
emit(
|
||||
"aten::native_batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float) -> (Tensor, Tensor, Tensor)"
|
||||
|
|
Loading…
Reference in New Issue