diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index b42ba19b4..6e8ac65ea 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -621,4 +621,6 @@ LTC_XFAIL_SET = {
     "Fill_TensorFloat32WithFloat32_basic",
     "Fill_TensorFloat32WithFloat64_basic",
     "Fill_TensorFloat32WithInt64_basic",
+    "UpSampleNearest2dBackwardVec_basic",
+    "UpSampleNearest2dBackwardOutputSizeNone_basic",
 }
diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h
index bf944ee9c..a65a985af 100644
--- a/include/torch-mlir/Conversion/Utils/Utils.h
+++ b/include/torch-mlir/Conversion/Utils/Utils.h
@@ -49,6 +49,10 @@
 SmallVector<Value> castIntVectorToIndexVector(OpBuilder &b, Location loc,
                                               SmallVectorImpl<Value> &intValues);
 
+SmallVector<Value>
+castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
+                             SmallVectorImpl<Value> &indexValues);
+
 Value getDimOp(OpBuilder &b, Location loc, Value v, int dim);
 
 SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4057e9972..9495a01fe 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -4893,6 +4893,32 @@ def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [
   }];
 }
 
+def Torch_AtenUpsampleNearest2dBackwardVecOp : Torch_Op<"aten.upsample_nearest2d_backward.vec", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchOptionalListOfTorchIntType:$output_size,
+    AnyTorchListOfTorchIntType:$input_size,
+    AnyTorchOptionalListOfTorchFloatType:$scale_factors
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenUpsampleNearest2dBackwardVecOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenUpsampleNearest2dBackwardVecOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index ef85aac58..3c8d036c0 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -896,6 +896,190 @@ public:
 };
 } // namespace
 
+static Value getGradOutputValue(OpBuilder &builder, Location loc,
+                                Value gradOutput, Type gradOutputElemType,
+                                Value numBatch, Value numChannel,
+                                Value inputIndexH, Value inputIndexW,
+                                Value kernelIndexH, Value kernelIndexW,
+                                SmallVector<Value> &gradOutputSizeIndexValues,
+                                SmallVector<Value> &scaleFactorsIntValues) {
+  Value constantOne = builder.create<arith::ConstantIndexOp>(loc, 1);
+
+  Value outputIndexH = builder.create<arith::MulIOp>(
+      loc, inputIndexH, castIntToIndex(builder, loc, scaleFactorsIntValues[0]));
+  outputIndexH = builder.create<arith::AddIOp>(loc, outputIndexH, kernelIndexH);
+
+  Value outputIndexW = builder.create<arith::MulIOp>(
+      loc, inputIndexW, castIntToIndex(builder, loc, scaleFactorsIntValues[1]));
+  outputIndexW = builder.create<arith::AddIOp>(loc, outputIndexW, kernelIndexW);
+
+  // Handling corner cases.
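+  // Because kh and kw are computed with ceil(), the last computed output
+  // index along H or W can land past the end of grad_output. Clamp it to the
+  // last valid row/column here; the select at the bottom zeroes out these
+  // clamped (out-of-bounds) contributions so they do not perturb the sum.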
+  Value gradOutputHMinusOne = builder.create<arith::SubIOp>(
+      loc, gradOutputSizeIndexValues[2], constantOne);
+  Value predH = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sle, outputIndexH, gradOutputHMinusOne);
+  outputIndexH = builder.create<arith::SelectOp>(loc, predH, outputIndexH,
+                                                 gradOutputHMinusOne);
+
+  Value gradOutputWMinusOne = builder.create<arith::SubIOp>(
+      loc, gradOutputSizeIndexValues[3], constantOne);
+  Value predW = builder.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::sle, outputIndexW, gradOutputWMinusOne);
+  outputIndexW = builder.create<arith::SelectOp>(loc, predW, outputIndexW,
+                                                 gradOutputWMinusOne);
+
+  Value gradOutputValue = builder.create<tensor::ExtractOp>(
+      loc, gradOutput,
+      ValueRange{numBatch, numChannel, outputIndexH, outputIndexW});
+  Value constantZero =
+      builder.create<arith::ConstantOp>(loc, builder.getF32FloatAttr(0.0));
+  Value pred = builder.create<arith::AndIOp>(loc, predH, predW);
+  Value result = builder.create<arith::SelectOp>(
+      loc, pred, gradOutputValue,
+      convertScalarToDtype(builder, loc, constantZero, gradOutputElemType));
+
+  return result;
+}
+
+// The implementation of the `aten.upsample_nearest2d_backward.vec` op's
+// lowering is as follows:
+// gradOutput: Tensor of size [n, c, oh, ow]
+// outTensor: Tensor of size [n, c, ih, iw], initialized with zeros
+// kh = ceil(oh/ih), kw = ceil(ow/iw)
+//
+// for i in range(n):
+//   for j in range(c):
+//     for p in range(ih):
+//       for q in range(iw):
+//         for x in range(kh):
+//           for y in range(kw):
+//             outTensor[i, j, p, q] += gradOutput[i, j, (p*kh)+x, (q*kw)+y]
+namespace {
+class ConvertAtenUpsampleNearest2dBackwardVecOp
+    : public OpConversionPattern<AtenUpsampleNearest2dBackwardVecOp> {
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenUpsampleNearest2dBackwardVecOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+    Value gradOutput = adaptor.grad_output();
+
+    Type resultType =
+        getTypeConverter()->convertType(op.getResult().getType());
+    auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
+    auto gradOutputRank = gradOutputType.getRank();
+    Type elementType = gradOutputType.getElementType();
+
+    SmallVector<Value> gradOutputSizeIndexValues =
+        getTensorSizes(rewriter, loc, gradOutput);
+    SmallVector<Value> gradOutputSizeIntValues =
+        castIndexVectorToInt64Vector(rewriter, loc, gradOutputSizeIndexValues);
+    SmallVector<Value> scaleFactorsFloatValues;
+
+    SmallVector<Value> inputSizeTorchInt;
+    if (!getListConstructElements(op.input_size(), inputSizeTorchInt))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: the input_size is not constructed from "
+              "ListConstruct");
+    SmallVector<Value> inputSizeIntValues;
+    inputSizeIntValues = getTypeConvertedValues(
+        rewriter, loc, getTypeConverter(), inputSizeTorchInt);
+
+    // The dimension at which the scaling starts.
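+    // For an NCHW tensor only the trailing spatial dimensions (H at dim 2
+    // and W at dim 3) are rescaled, so scale factors are derived from here on.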
+    unsigned hDimOffset = 2;
+
+    if (!op.scale_factors().getType().isa<Torch::NoneType>()) {
+      SmallVector<Value> scaleFactorsTorchFloat;
+      if (!getListConstructElements(op.scale_factors(),
+                                    scaleFactorsTorchFloat))
+        return rewriter.notifyMatchFailure(
+            op, "unimplemented: the scale_factors is not constructed from "
+                "ListConstruct");
+      scaleFactorsFloatValues = getTypeConvertedValues(
+          rewriter, loc, getTypeConverter(), scaleFactorsTorchFloat);
+    } else {
+      // No scale factors were supplied, so derive them from the ratio of the
+      // grad_output and input spatial sizes.
+      for (unsigned i = hDimOffset; i < gradOutputRank; i++) {
+        auto scaleFactorVal = rewriter.create<arith::DivFOp>(
+            loc,
+            convertScalarToDtype(rewriter, loc, gradOutputSizeIntValues[i],
+                                 mlir::Float32Type::get(op->getContext())),
+            convertScalarToDtype(rewriter, loc, inputSizeIntValues[i],
+                                 mlir::Float32Type::get(op->getContext())));
+        scaleFactorsFloatValues.push_back(scaleFactorVal);
+      }
+    }
+
+    SmallVector<Value> scaleFactorsIntValues;
+    for (auto v : scaleFactorsFloatValues)
+      scaleFactorsIntValues.push_back(convertScalarToDtype(
+          rewriter, loc, rewriter.create<math::CeilOp>(loc, v),
+          mlir::IntegerType::get(op->getContext(), 64)));
+
+    Value outTensor = createZeroInitTensor(
+        rewriter, loc,
+        castIntVectorToIndexVector(rewriter, loc, inputSizeIntValues),
+        elementType);
+
+    Value kernelTensor = rewriter.create<linalg::InitTensorOp>(
+        loc,
+        getAsOpFoldResult(
+            castIntVectorToIndexVector(rewriter, loc, scaleFactorsIntValues)),
+        elementType);
+    unsigned kernelRank = scaleFactorsIntValues.size();
+
+    SmallVector<AffineExpr> affineExprs;
+    for (unsigned i = 0; i < gradOutputRank; i++)
+      affineExprs.push_back(rewriter.getAffineDimExpr(i));
+
+    AffineMap outputMap =
+        AffineMap::get(gradOutputRank + kernelRank,
+                       /*symbolCount=*/0, affineExprs, op->getContext());
+
+    affineExprs.clear();
+    for (unsigned i = gradOutputRank; i < gradOutputRank + kernelRank; i++)
+      affineExprs.push_back(rewriter.getAffineDimExpr(i));
+
+    AffineMap kernelMap =
+        AffineMap::get(gradOutputRank + kernelRank,
+                       /*symbolCount=*/0, affineExprs, op->getContext());
+
+    SmallVector<AffineMap> indexingMaps{kernelMap, outputMap};
+    SmallVector<StringRef> iteratorTypes(gradOutputRank,
+                                         getParallelIteratorTypeName());
+    iteratorTypes.push_back(getReductionIteratorTypeName());
+    iteratorTypes.push_back(getReductionIteratorTypeName());
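+
+    // The kernel tensor is never read in the region below: it only
+    // contributes the two innermost reduction dimensions (kh, kw) to the
+    // iteration space of the linalg.generic, whose body gathers from
+    // gradOutput by computed index and accumulates into outTensor.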
+    Value finalRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, outTensor.getType(), ValueRange{kernelTensor},
+                ValueRange{outTensor},
+                /*indexingMaps=*/indexingMaps,
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value n = rewriter.create<linalg::IndexOp>(loc, 0);
+                  Value c = rewriter.create<linalg::IndexOp>(loc, 1);
+                  Value ih = rewriter.create<linalg::IndexOp>(loc, 2);
+                  Value iw = rewriter.create<linalg::IndexOp>(loc, 3);
+                  Value kh = rewriter.create<linalg::IndexOp>(loc, 4);
+                  Value kw = rewriter.create<linalg::IndexOp>(loc, 5);
+                  Value accValue = getGradOutputValue(
+                      rewriter, loc, gradOutput, elementType, n, c, ih, iw,
+                      kh, kw, gradOutputSizeIndexValues,
+                      scaleFactorsIntValues);
+                  Value outputVal = args[1];
+                  outputVal =
+                      rewriter.create<arith::AddFOp>(loc, outputVal, accValue);
+                  b.create<linalg::YieldOp>(loc, outputVal);
+                })
+            ->getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, finalRes);
+    return success();
+  }
+};
+} // namespace
+
 void mlir::torch::torch_to_linalg::
     populateIndirectDataMovementPatternsAndLegality(
         TypeConverter &typeConverter, RewritePatternSet &patterns,
@@ -913,4 +1097,6 @@ void mlir::torch::torch_to_linalg::
   patterns.add<ConvertAtenIndexTensorHackedTwinOp>(typeConverter, context);
   target.addIllegalOp<AtenUpsampleNearest2dVecOp>();
   patterns.add<ConvertAtenUpsampleNearest2dVecOp>(typeConverter, context);
+  target.addIllegalOp<AtenUpsampleNearest2dBackwardVecOp>();
+  patterns.add<ConvertAtenUpsampleNearest2dBackwardVecOp>(typeConverter, context);
 }
diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp
index e291d9bce..a7d0f36a7 100644
--- a/lib/Conversion/Utils/Utils.cpp
+++ b/lib/Conversion/Utils/Utils.cpp
@@ -156,6 +156,15 @@ castIntVectorToIndexVector(OpBuilder &b, Location loc,
   return indexValues;
 }
 
+SmallVector<Value>
+castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
+                             SmallVectorImpl<Value> &indexValues) {
+  SmallVector<Value> intValues;
+  for (Value v : indexValues)
+    intValues.push_back(castIndexToInt64(b, loc, v));
+  return intValues;
+}
+
 Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
   return b.createOrFold<tensor::DimOp>(loc, v, dim);
 }
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 88e4e2ed2..706609e7b 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -700,8 +700,8 @@ void TypeAnalysis::visitOperation(Operation *op,
           AtenMaskedFillScalarOp, AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp,
           AtenTriuOp, AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
           AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
-          AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp,
-          AtenFillTensorOp>(op)) {
+          AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
+          AtenUpsampleNearest2dBackwardVecOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
   }
 
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index 406594bfe..f3d221e9a 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6076,6 +6076,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.bool, %arg7: !torch.list<int>) -> !torch.list<int> {\n"
 "    return %arg1 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward.vec\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<float>>) -> !torch.list<int> {\n"
+"    return %arg2 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.avg_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.avg_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.optional<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 661031fd0..1c43e8624 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -690,6 +690,9 @@ def aten〇max_pool2d_with_indices(self: List[int], kernel_size: List[int], stri
 def aten〇max_pool2d_with_indices_backward(grad_output: List[int], self: List[int], kernel_size: List[int], stride: List[int], padding: List[int], dilation: List[int], ceil_mode: bool, indices: List[int]) -> List[int]:
     return self
 
+def aten〇upsample_nearest2d_backward〇vec(grad_output: List[int], output_size: Optional[List[int]], input_size: List[int], scale_factors: Optional[List[float]]) -> List[int]:
+    return input_size
+
 # TODO: This should be upstreamed.
 # See https://github.com/pytorch/pytorch/pull/76889 for an example.
 def avg_pool2d(input: List[int], kernel_size: List[int], stride: List[int], padding: List[int], ceil_mode: bool, count_include_pad: bool, divisor_override: Optional[int]):
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index fb1650081..2b7c0f289 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -407,6 +407,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
     emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
+    emit("aten::upsample_nearest2d_backward.vec : (Tensor, int[]?, int[], float[]?) -> (Tensor)")
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 95ee23d99..cf8929d90 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -3021,4 +3021,51 @@ class SingleTensorTupleReturn(torch.nn.Module):
 
 @register_test_case(module_factory=lambda: SingleTensorTupleReturn())
 def SingleTensorTupleReturn_basic(module, tu: TestUtils):
-    module.forward(torch.randn(2, 4))
\ No newline at end of file
+    module.forward(torch.randn(2, 4))
+
+
+# ==============================================================================
+
+
+class UpSampleNearest2dBackwardVec(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, input):
+        return torch.ops.aten.upsample_nearest2d_backward(input,
+                                                          output_size=[4, 8],
+                                                          input_size=[1, 1, 2, 3],
+                                                          scale_factors=None)
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardVec())
+def UpSampleNearest2dBackwardVec_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 4, 8))
+
+
+class UpSampleNearest2dBackwardOutputSizeNone(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, input):
+        return torch.ops.aten.upsample_nearest2d_backward(input,
+                                                          output_size=None,
+                                                          input_size=[1, 1, 2, 3],
+                                                          scale_factors=[3.0, 4.0])
+
+
+@register_test_case(module_factory=lambda: UpSampleNearest2dBackwardOutputSizeNone())
+def UpSampleNearest2dBackwardOutputSizeNone_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 6, 12).to(torch.float64))
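
---

Reviewer note (not part of the patch): below is a minimal Python sketch of the
accumulation the `linalg.generic` above implements, mirroring the pseudocode
comment in IndirectDataMovement.cpp. The function name and the derived-scale
path are illustrative assumptions; out-of-range taps contribute zero, matching
the clamp-and-select in `getGradOutputValue`.

import math
import torch

def upsample_nearest2d_backward_ref(grad_output, input_size):
    # grad_output: [n, c, oh, ow]; result: [n, c, ih, iw] per input_size.
    n, c, ih, iw = input_size
    oh, ow = grad_output.shape[2], grad_output.shape[3]
    # Ceil-derived per-axis kernel extents, as in the lowering's
    # scale-factor-to-integer conversion.
    kh, kw = math.ceil(oh / ih), math.ceil(ow / iw)
    out = torch.zeros(input_size, dtype=grad_output.dtype)
    for i in range(n):
        for j in range(c):
            for p in range(ih):
                for q in range(iw):
                    for x in range(kh):
                        for y in range(kw):
                            h, w = p * kh + x, q * kw + y
                            # Taps past the grad_output extent are skipped,
                            # i.e. masked to zero.
                            if h < oh and w < ow:
                                out[i, j, p, q] += grad_output[i, j, h, w]
    return out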