diff --git a/e2e_testing/torchscript/backprop.py b/e2e_testing/torchscript/backprop.py
new file mode 100644
index 000000000..4b18993ef
--- /dev/null
+++ b/e2e_testing/torchscript/backprop.py
@@ -0,0 +1,35 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+class SoftmaxBackwardModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, grad_output, output):
+        return torch.ops.aten._softmax_backward_data(grad_output,
+                                                     output,
+                                                     dim=1,
+                                                     input_dtype=6)
+
+
+@register_test_case(module_factory=lambda: SoftmaxBackwardModule())
+def SoftmaxBackwardModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
+
diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index c2debfd94..ab2d1c569 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -224,7 +224,7 @@ def TransposeIntModule_basic(module, tu: TestUtils):
 class PermuteModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        
+
     @export
     @annotate_args([
         None,
@@ -258,7 +258,7 @@ def TransposeIntNegDimsModule_basic(module, tu: TestUtils):
 class PermuteNegativeIndexModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        
+
     @export
     @annotate_args([
         None,
@@ -374,7 +374,6 @@ def EmbeddingModule_basic(module, tu: TestUtils):
 class SoftmaxIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        torch.manual_seed(0)
         self.softmax = torch.nn.Softmax(2)
 
     @export
@@ -429,6 +428,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
 def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).double())
 
+
 class BroadcastToModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -509,7 +509,7 @@ class ContiguousModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ContiguousModule())
 def ContiguousModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 1))
-    
+
 class TensorToInt(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -522,6 +522,7 @@ class TensorToInt(torch.nn.Module):
     def forward(self, x):
         return int(x)
 
+
 @register_test_case(module_factory=lambda: TensorToInt())
 def TensorToInt_basic(module, tu: TestUtils):
     module.forward(torch.randint(10,[]))
@@ -543,6 +544,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
 def LogSoftmaxIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).double())
 
+
 class NumToTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py
index 3bcbd0c7c..2f365f40d 100644
--- a/e2e_testing/torchscript/main.py
+++ b/e2e_testing/torchscript/main.py
@@ -35,6 +35,7 @@ from . import quantized_models
 from . import elementwise
 from . import type_promotion
 from . import type_conversion
+from . import backprop
 from . import reduction
 from . import argmax
 from . import matmul
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
index 672ed7fc7..ada69a2bb 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -2785,3 +2785,20 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
 }
 
+def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$output,
+    Torch_IntType:$dim,
+    Torch_IntType:$input_dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
+}
+
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index a402cc50e..f1dfd9cc2 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -33,6 +33,68 @@ static int getTensorRank(Value tensor) {
   return tensorRank;
 }
 
+static Value createAtenSum(PatternRewriter &rewriter, Location loc,
+                           Operation *op, Value input, Value dim,
+                           bool keepDim) {
+  BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
+  Value dimList = rewriter.create<PrimListConstructOp>(
+      loc, Torch::ListType::get(dim.getType()), dim);
+  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
+  Value dtype = rewriter.create<ConstantNoneOp>(loc);
+  SmallVector<int64_t> sizes;
+  int64_t dimInt;
+  if (tensorType.hasSizes()) {
+    ArrayRef<int64_t> inputShape = tensorType.getSizes();
+    int64_t inputRank = inputShape.size();
+    if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
+      dimInt = toPositiveDim(dimInt, inputRank);
+      if (!isValidDim(dimInt, inputRank)) {
+        (void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+        return nullptr;
+      }
+      sizes.append(inputShape.begin(), inputShape.end());
+      sizes[dimInt] = 1;
+    } else {
+      sizes.resize(inputRank, kUnknownSize);
+    }
+  }
+
+  Type resultType = tensorType.getWithSizesAndDtype(
+      sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
+                        : llvm::makeArrayRef(sizes),
+      tensorType.getDtype());
+  Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, input,
+                                                   dimList, keepDimCst, dtype);
+  return sum;
+}
+
+namespace {
+class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSizeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.self();
+    MLIRContext *context = op.getContext();
+    int64_t rank = getTensorRank(self);
+    if (rank < 0)
+      return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
+    SmallVector<Value> sizes;
+    for (int i = 0; i < rank; i++) {
+      Value dim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(i));
+      sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
+    }
+
+    Value sizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)), sizes);
+    rewriter.replaceOp(op, sizeList);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose softmax into: exp(x) / sum(exp(x))
 namespace {
 class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
@@ -50,35 +112,13 @@ public:
     BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
     if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(op, "Only support floating type");
+
     // exp(x)
     Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
 
-    // sum(exp(x))
-    Value dimList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(dim.getType()), dim);
-    Value keepDim = rewriter.create<ConstantBoolOp>(loc, true);
-    Value dtype = rewriter.create<ConstantNoneOp>(loc);
-    SmallVector<int64_t> sizes;
-    int64_t dimInt;
-    if (tensorType.hasSizes()) {
-      ArrayRef<int64_t> inputShape = tensorType.getSizes();
-      int64_t inputRank = inputShape.size();
-      if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
-        dimInt = toPositiveDim(dimInt, inputRank);
-        if (!isValidDim(dimInt, inputRank))
-          return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
-        sizes.append(inputShape.begin(), inputShape.end());
-        sizes[dimInt] = 1;
-      } else {
-        sizes.resize(inputRank, kUnknownSize);
-      }
-    }
-    Type resultType = tensorType.getWithSizesAndDtype(
-        sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
-                          : llvm::makeArrayRef(sizes),
-        tensorType.getDtype());
-    Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, exp,
-                                                     dimList, keepDim, dtype);
+    Value sum = createAtenSum(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+    if (!sum)
+      return failure();
     // exp(x) / sum(exp(x))
     Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -88,6 +128,56 @@ public:
 };
 } // namespace
 
+// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
+//    newGrad = gradOutput * output
+//    result = newGrad - output * sum(newGrad, dim)
+//
+// Refer to
+// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
+namespace {
+class DecomposeAten_SoftmaxBackwardDataOp
+    : public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+    Value gradOutput = op.grad_output();
+    Value output = op.output();
+    Value dim = op.dim();
+
+    BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
+    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(op, "Only support floating type");
+
+    Value newGrad =
+        rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
+    // temp = output * sum(newGrad, dim)
+    Value sum =
+        createAtenSum(rewriter, loc, op, newGrad, dim, /*keepDim=*/true);
+    if (!sum)
+      return failure();
+    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
+    Value broadcastSize =
+        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
+    Value sumBroadcast =
+        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
+    Value temp =
+        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, sumBroadcast);
+
+    // newGrad - temp
+    Value alpha =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
+    Value sub =
+        rewriter.create<AtenSubTensorOp>(loc, tensorType, newGrad, temp, alpha);
+
+    rewriter.replaceOp(op, sub);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.log_softmax op into: log(softmax(x))
 namespace {
 class DecomposeAtenLogSoftmaxIntOp
@@ -177,6 +267,10 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenSoftmaxIntOp>();
     patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
     target.addIllegalOp<AtenLogSoftmaxIntOp>();
+    patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
+    target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
+    patterns.add<DecomposeAtenSizeOp>(context);
+    target.addIllegalOp<AtenSizeOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
     target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
       int lhsRank = getTensorRank(op.self());
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 7f8caea9f..a458e7676 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -230,7 +230,7 @@ public:
             AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
             AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
             AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
-            AtenFloorOp, AtenLog2Op>(op)) {
+            AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index 19bbe3a1b..30e2c4e02 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -621,6 +621,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::div : (Scalar, Scalar) -> (float)")
     emit("aten::eq.device : (Device, Device) -> (bool)")
 
+    # backprop ops
+    emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
+
 
 def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
     td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 127db9612..c06dac787 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -107,3 +107,17 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.t
   %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<*,f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
   return %ret : !torch.tensor<*,f32>
 }
+
+// ----
+// CHECK-LABEL:   func @torch.aten.size(
+// CHECK-SAME:        %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
+// CHECK:           %[[CST0:.*]] = torch.constant.int 0
+// CHECK:           %[[DIM0:.*]] = torch.aten.size.int %[[T]], %[[CST0]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
+// CHECK:           %[[CST1:.*]] = torch.constant.int 1
+// CHECK:           %[[DIM1:.*]] = torch.aten.size.int %[[T]], %[[CST1]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
+// CHECK:           %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK:           return %[[SIZE]] : !torch.list<int>
+func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<int> {
+  %0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<int>
+  return %0 : !torch.list<int>
+}
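
Not part of the patch itself: the sketch below is a quick standalone check of the formula used by DecomposeAten_SoftmaxBackwardDataOp (newGrad = grad_output * output; result = newGrad - output * sum(newGrad, dim)), compared against the softmax gradient PyTorch's autograd computes. The shapes and dim=1 mirror SoftmaxBackwardModule above; the variable names and tolerance are illustrative assumptions only.

# Standalone sketch (not part of the patch): verify the decomposed softmax
# backward formula against autograd. Shapes and dim=1 follow the
# SoftmaxBackwardModule test; names and tolerance are assumptions.
import torch

x = torch.randn(3, 2, 4, requires_grad=True)
output = torch.softmax(x, dim=1)
grad_output = torch.randn(3, 2, 4)

# Reference gradient of the softmax output w.r.t. x for this grad_output.
(reference,) = torch.autograd.grad(output, x, grad_output)

# Decomposition used by the pattern:
#   newGrad = grad_output * output
#   result  = newGrad - output * sum(newGrad, dim, keepdim=True)
new_grad = grad_output * output
decomposed = new_grad - output * new_grad.sum(dim=1, keepdim=True)

assert torch.allclose(reference, decomposed, atol=1e-6)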