diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 9495a01fe..2a14ac3da 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4893,28 +4893,29 @@ def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [
   }];
 }
 
-def Torch_AtenUpsampleNearest2dBackwardVecOp : Torch_Op<"aten.upsample_nearest2d_backward.vec", [
+def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)`";
+  let summary = "Generated op for `aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$grad_output,
-    AnyTorchOptionalListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$output_size,
     AnyTorchListOfTorchIntType:$input_size,
-    AnyTorchOptionalListOfTorchFloatType:$scale_factors
+    AnyTorchOptionalFloatType:$scales_h,
+    AnyTorchOptionalFloatType:$scales_w
   );
   let results = (outs
     AnyTorchTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenUpsampleNearest2dBackwardVecOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 4, 1);
+    ParseResult AtenUpsampleNearest2dBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
     }
-    void AtenUpsampleNearest2dBackwardVecOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 4, 1);
+    void AtenUpsampleNearest2dBackwardOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
 }
diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index 3c8d036c0..5e57b8d61 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -955,13 +955,13 @@ static Value getGradOutputValue(OpBuilder &builder, Location loc,
 //             for y in range(kw):
 //                 outTensor[i, j, p, q] += gradOutput[i, j, (p*kh)+x, (q*kw)+y]
 namespace {
-class ConvertAtenUpsampleNearest2dBackwardVecOp
-    : public OpConversionPattern<AtenUpsampleNearest2dBackwardVecOp> {
+class ConvertAtenUpsampleNearest2dBackwardOp
+    : public OpConversionPattern<AtenUpsampleNearest2dBackwardOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenUpsampleNearest2dBackwardVecOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenUpsampleNearest2dBackwardOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
     Location loc = op->getLoc();
@@ -976,7 +976,6 @@ public:
         getTensorSizes(rewriter, loc, gradOutput);
     SmallVector<Value> gradOutputSizeIntValues =
         castIndexVectorToInt64Vector(rewriter, loc, gradOutputSizeIndexValues);
-    SmallVector<Value> scaleFactorsFloatValues;
 
     SmallVector<Value> inputSizeTorchInt;
     if (!getListConstructElements(op.input_size(), inputSizeTorchInt))
@@ -990,24 +989,32 @@ public:
 
     // The dimension at which the scaling starts.
     unsigned hDimOffset = 2;
-    if (!op.scale_factors().getType().isa<Torch::NoneType>()) {
-      SmallVector<Value> scaleFactorsTorchFloat;
-      if (!getListConstructElements(op.scale_factors(), scaleFactorsTorchFloat))
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: the scale_factors is not constructed from "
-                "ListConstruct");
-      scaleFactorsFloatValues = getTypeConvertedValues(
-          rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
+    SmallVector<Value> scaleFactorsFloatValues;
+    if (!op.scales_h().getType().isa<Torch::NoneType>()) {
+      scaleFactorsFloatValues.push_back(adaptor.scales_h());
     } else {
-      for (unsigned i = hDimOffset; i < gradOutputRank; i++) {
-        auto scaleFactorVal = rewriter.create<arith::DivFOp>(
-            loc,
-            convertScalarToDtype(rewriter, loc, gradOutputSizeIntValues[i],
-                                 mlir::Float32Type::get(op->getContext())),
-            convertScalarToDtype(rewriter, loc, inputSizeIntValues[i],
-                                 mlir::Float32Type::get(op->getContext())));
-        scaleFactorsFloatValues.push_back(scaleFactorVal);
-      }
+      auto scaleFactorVal = rewriter.create<arith::DivFOp>(
+          loc,
+          convertScalarToDtype(rewriter, loc,
+                               gradOutputSizeIntValues[hDimOffset],
+                               mlir::Float32Type::get(op->getContext())),
+          convertScalarToDtype(rewriter, loc, inputSizeIntValues[hDimOffset],
+                               mlir::Float32Type::get(op->getContext())));
+      scaleFactorsFloatValues.push_back(scaleFactorVal);
+    }
+
+    if (!op.scales_w().getType().isa<Torch::NoneType>()) {
+      scaleFactorsFloatValues.push_back(adaptor.scales_w());
+    } else {
+      auto scaleFactorVal = rewriter.create<arith::DivFOp>(
+          loc,
+          convertScalarToDtype(rewriter, loc,
+                               gradOutputSizeIntValues[hDimOffset + 1],
+                               mlir::Float32Type::get(op->getContext())),
+          convertScalarToDtype(rewriter, loc,
+                               inputSizeIntValues[hDimOffset + 1],
+                               mlir::Float32Type::get(op->getContext())));
+      scaleFactorsFloatValues.push_back(scaleFactorVal);
     }
 
     SmallVector<Value> scaleFactorsIntValues;
@@ -1097,6 +1104,6 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
   target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
   patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
-  target.addIllegalOp<AtenUpsampleNearest2dBackwardVecOp>();
-  patterns.add<ConvertAtenUpsampleNearest2dBackwardVecOp>(typeConverter, context);
+  target.addIllegalOp<AtenUpsampleNearest2dBackwardOp>();
+  patterns.add<ConvertAtenUpsampleNearest2dBackwardOp>(typeConverter, context);
 }
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 706609e7b..1ee372103 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -701,7 +701,7 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
           AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
           AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
-          AtenUpsampleNearest2dBackwardVecOp>(op)) {
+          AtenUpsampleNearest2dBackwardOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index f3d221e9a..b40ef2207 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6076,7 +6076,7 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg1 : !torch.list<int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
 "    return %arg2 : !torch.list<int>\n"
 "  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 18853df3e..c3934eb5f 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -690,7 +690,7 @@ def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stri
 def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
     return self
 
-def aten〇upsample_nearest2d_backward〇vec(grad_output: List[int], output_size: Optional[List[int]], input_size: List[int], scale_factors: Optional[List[float]]) -> List[int]:
+def aten〇upsample_nearest2d_backward(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
     return input_size
 
 # TODO: This should be upstreamed.
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 2b7c0f289..7af5423a2 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -407,7 +407,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
-    emit("aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)")
+    emit("aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)")
 
     # Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index aedbdd5ff..0eee76b1b 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -3010,7 +3010,30 @@ def AtenToDeviceModule_basic(module, tu: TestUtils): # ============================================================================== -class UpSampleNearest2dBackwardVec(torch.nn.Module): +class UpSampleNearest2dBackward(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float64, True), + ]) + def forward(self, input): + return torch.ops.aten.upsample_nearest2d_backward(input, + output_size=[6, 12], + input_size=[1, 1, 2, 3], + scales_h=3.0, + scales_w=4.0) + + +@register_test_case(module_factory=lambda: UpSampleNearest2dBackward()) +def UpSampleNearest2dBackward_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 6, 12).to(torch.float64)) + + +class UpSampleNearest2dBackwardScalesNone(torch.nn.Module): def __init__(self): super().__init__() @@ -3024,31 +3047,9 @@ class UpSampleNearest2dBackwardVec(torch.nn.Module): return torch.ops.aten.upsample_nearest2d_backward(input, output_size=[4, 8], input_size=[1, 1, 2, 3], - scale_factors=None) + scales_h=None, + scales_w=None) - -@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardVec()) -def UpSampleNearest2dBackwardVec_basic(module, tu: TestUtils): +@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardScalesNone()) +def UpSampleNearest2dBackwardScalesNone_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 4, 8)) - - -class UpSampleNearest2dBackwardOutputSizeNone(torch.nn.Module): - - def __init__(self): - super().__init__() - - @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.float64, True), - ]) - def forward(self, input): - return torch.ops.aten.upsample_nearest2d_backward(input, - output_size=None, - input_size=[1, 1, 2, 3], - scale_factors=[3.0, 4.0]) - - -@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardOutputSizeNone()) -def UpSampleNearest2dBackwardOutputSizeNone_basic(module, tu: TestUtils): - module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))