Add e2e support for aten._softmax_backward_data.

Decompose `aten._softmax_backward_data` into aten math ops. Also decompose
`aten.size` to facilitate decomposing `_softmax_backward_data`.
Yi Zhang 2021-11-08 10:56:40 -05:00 committed by Prashant Kumar
parent 05c4dd8e39
commit 3bd9d2a4c7
8 changed files with 197 additions and 31 deletions
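
For reference, the decomposition implements the usual softmax vector-Jacobian
product: for y = softmax(x, dim) and upstream gradient g, the input gradient is
g*y - y*sum(g*y, dim, keepdim=True). A quick eager-mode check of that identity
(a sketch for intuition only, not part of the change):

import torch

x = torch.randn(3, 2, 4, requires_grad=True)
y = torch.softmax(x, dim=1)
g = torch.randn(3, 2, 4)

# The formula this patch decomposes the op into:
#   newGrad = g * y
#   result  = newGrad - y * sum(newGrad, dim, keepdim=True)
new_grad = g * y
decomposed = new_grad - y * new_grad.sum(dim=1, keepdim=True)

# Compare against autograd's softmax backward.
(reference,) = torch.autograd.grad(y, x, grad_outputs=g)
assert torch.allclose(decomposed, reference, atol=1e-6)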

@@ -0,0 +1,35 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+class SoftmaxBackwardModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, grad_output, output):
+        return torch.ops.aten._softmax_backward_data(grad_output,
+                                                     output,
+                                                     dim=1,
+                                                     input_dtype=6)
+
+@register_test_case(module_factory=lambda: SoftmaxBackwardModule())
+def SoftmaxBackwardModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
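
A note on the literal above: input_dtype=6 is the integer code for
torch.float32 in PyTorch's ScalarType enum (7 would be float64). Assuming an
eager PyTorch build that accepts the integer form of the dtype argument, the
test body can be reproduced directly (a sketch):

import torch

g = torch.randn(3, 2, 4)
y = torch.softmax(torch.randn(3, 2, 4), dim=1)

# input_dtype=6 selects ScalarType::Float, i.e. torch.float32.
out = torch.ops.aten._softmax_backward_data(g, y, 1, 6)
ref = y * (g - (g * y).sum(dim=1, keepdim=True))
assert torch.allclose(out, ref)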

@@ -374,7 +374,6 @@ def EmbeddingModule_basic(module, tu: TestUtils):
 class SoftmaxIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        torch.manual_seed(0)
         self.softmax = torch.nn.Softmax(2)

     @export
@@ -429,6 +428,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
 def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).double())

+
 class BroadcastToModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -522,6 +522,7 @@ class TensorToInt(torch.nn.Module):
     def forward(self, x):
         return int(x)

+
 @register_test_case(module_factory=lambda: TensorToInt())
 def TensorToInt_basic(module, tu: TestUtils):
     module.forward(torch.randint(10,[]))
@@ -543,6 +544,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
 def LogSoftmaxIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).double())

+
 class NumToTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -35,6 +35,7 @@ from . import quantized_models
 from . import elementwise
 from . import type_promotion
 from . import type_conversion
+from . import backprop
 from . import reduction
 from . import argmax
 from . import matmul

@@ -2785,3 +2785,20 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
 }

+def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$output,
+    Torch_IntType:$dim,
+    Torch_IntType:$input_dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
+}

@@ -33,6 +33,68 @@ static int getTensorRank(Value tensor) {
   return tensorRank;
 }

+static Value createAtenSum(PatternRewriter &rewriter, Location loc,
+                           Operation *op, Value input, Value dim,
+                           bool keepDim) {
+  BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
+  Value dimList = rewriter.create<PrimListConstructOp>(
+      loc, Torch::ListType::get(dim.getType()), dim);
+  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
+  Value dtype = rewriter.create<ConstantNoneOp>(loc);
+  SmallVector<int64_t> sizes;
+  int64_t dimInt;
+  if (tensorType.hasSizes()) {
+    ArrayRef<int64_t> inputShape = tensorType.getSizes();
+    int64_t inputRank = inputShape.size();
+    if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
+      dimInt = toPositiveDim(dimInt, inputRank);
+      if (!isValidDim(dimInt, inputRank)) {
+        (void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+        return nullptr;
+      }
+      sizes.append(inputShape.begin(), inputShape.end());
+      sizes[dimInt] = 1;
+    } else {
+      sizes.resize(inputRank, kUnknownSize);
+    }
+  }
+  Type resultType = tensorType.getWithSizesAndDtype(
+      sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
+                        : llvm::makeArrayRef(sizes),
+      tensorType.getDtype());
+  Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, input,
+                                                   dimList, keepDimCst, dtype);
+  return sum;
+}
+
+namespace {
+class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSizeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.self();
+    MLIRContext *context = op.getContext();
+    int64_t rank = getTensorRank(self);
+    if (rank < 0)
+      return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
+    SmallVector<Value> sizes;
+    for (int i = 0; i < rank; i++) {
+      Value dim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(i));
+      sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
+    }
+    Value sizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)), sizes);
+    rewriter.replaceOp(op, sizeList);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose softmax into: exp(x) / sum(exp(x))
 namespace {
 class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
@@ -50,35 +112,13 @@ public:
     BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
     if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(op, "Only support floating type");

     // exp(x)
     Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
     // sum(exp(x))
-    Value dimList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(dim.getType()), dim);
-    Value keepDim = rewriter.create<ConstantBoolOp>(loc, true);
-    Value dtype = rewriter.create<ConstantNoneOp>(loc);
-    SmallVector<int64_t> sizes;
-    int64_t dimInt;
-    if (tensorType.hasSizes()) {
-      ArrayRef<int64_t> inputShape = tensorType.getSizes();
-      int64_t inputRank = inputShape.size();
-      if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
-        dimInt = toPositiveDim(dimInt, inputRank);
-        if (!isValidDim(dimInt, inputRank))
-          return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
-        sizes.append(inputShape.begin(), inputShape.end());
-        sizes[dimInt] = 1;
-      } else {
-        sizes.resize(inputRank, kUnknownSize);
-      }
-    }
-    Type resultType = tensorType.getWithSizesAndDtype(
-        sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
-                          : llvm::makeArrayRef(sizes),
-        tensorType.getDtype());
-    Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, exp,
-                                                     dimList, keepDim, dtype);
+    Value sum = createAtenSum(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+    if (!sum)
+      return failure();
     // exp(x) / sum(exp(x))
     Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -88,6 +128,56 @@ public:
   }
 };
 } // namespace

+// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
+//   newGrad = gradOutput * output
+//   result = newGrad - output * sum(newGrad, dim)
+//
+// Refer to
+// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
+namespace {
+class DecomposeAten_SoftmaxBackwardDataOp
+    : public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+    Value gradOutput = op.grad_output();
+    Value output = op.output();
+    Value dim = op.dim();
+
+    BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
+    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(op, "Only support floating type");
+
+    // newGrad = gradOutput * output
+    Value newGrad =
+        rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
+    // temp = output * sum(newGrad, dim)
+    Value sum =
+        createAtenSum(rewriter, loc, op, newGrad, dim, /*keepDim=*/true);
+    if (!sum)
+      return failure();
+    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
+    Value broadcastSize =
+        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
+    Value sumBroadcast =
+        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
+    Value temp =
+        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, sumBroadcast);
+    // result = newGrad - temp
+    Value alpha =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
+    Value sub =
+        rewriter.create<AtenSubTensorOp>(loc, tensorType, newGrad, temp, alpha);
+    rewriter.replaceOp(op, sub);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.log_softmax op into: log(softmax(x))
 namespace {
 class DecomposeAtenLogSoftmaxIntOp
@@ -177,6 +267,10 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenLogSoftmaxIntOp>();
     patterns.add<DecomposeAtenExpandOp>(context);
     target.addIllegalOp<AtenExpandOp>();
+    patterns.add<DecomposeAtenSizeOp>(context);
+    target.addIllegalOp<AtenSizeOp>();
+    patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
+    target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
     target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
       int lhsRank = getTensorRank(op.self());
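
Taken together, the new patterns rewrite the backward op into mul, sum,
broadcast_to, and sub, with createAtenSum handling the shape bookkeeping: the
keepdim sum has static size 1 along `dim` and is broadcast back to the output
shape before the final multiply. An eager-mode trace of the same op sequence
(a sketch for intuition only; the variable names are illustrative):

import torch

x = torch.randn(3, 2, 4, requires_grad=True)
y = torch.softmax(x, dim=1)            # the `output` operand
g = torch.randn(3, 2, 4)               # the `grad_output` operand

new_grad = g * y                       # AtenMulTensorOp
s = new_grad.sum(dim=1, keepdim=True)  # createAtenSum: shape [3, 1, 4]
s = s.broadcast_to(y.shape)            # AtenSizeOp + AtenBroadcastToOp
temp = y * s                           # AtenMulTensorOp
result = new_grad - temp               # AtenSubTensorOp with alpha = 1.0

# Matches autograd's vector-Jacobian product for softmax.
(ref,) = torch.autograd.grad(y, x, grad_outputs=g)
assert torch.allclose(result, ref, atol=1e-6)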

@@ -230,7 +230,7 @@ public:
             AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
             AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
             AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
-            AtenFloorOp, AtenLog2Op>(op)) {
+            AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
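
The one-line change above adds the backward op to the set of ops whose result
type simply joins the lattice element of operands[0]: the gradient has the same
shape and dtype as grad_output. The eager-mode invariant being encoded (a
sketch, reusing the integer dtype code from the e2e test):

import torch

g = torch.randn(3, 2, 4)
y = torch.softmax(torch.randn(3, 2, 4), dim=1)
out = torch.ops.aten._softmax_backward_data(g, y, 1, 6)

# The result always matches grad_output, the first operand.
assert out.shape == g.shape and out.dtype == g.dtype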

@@ -621,6 +621,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::div : (Scalar, Scalar) -> (float)")
     emit("aten::eq.device : (Device, Device) -> (bool)")

+    # backprop ops
+    emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
+

 def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
     td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")

@@ -107,3 +107,17 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
   %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<*,f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
   return %ret : !torch.tensor<*,f32>
 }
+
+// -----
+// CHECK-LABEL:   func @torch.aten.size(
+// CHECK-SAME:        %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
+// CHECK:           %[[CST0:.*]] = torch.constant.int 0
+// CHECK:           %[[DIM0:.*]] = torch.aten.size.int %[[T]], %[[CST0]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
+// CHECK:           %[[CST1:.*]] = torch.constant.int 1
+// CHECK:           %[[DIM1:.*]] = torch.aten.size.int %[[T]], %[[CST1]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
+// CHECK:           %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+// CHECK:           return %[[SIZE]] : !torch.list<!torch.int>
+func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
+  %0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<!torch.int>
+  return %0 : !torch.list<!torch.int>
+}
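
The FileCheck test pins down the expected expansion of aten.size: one
torch.constant.int and torch.aten.size.int per dimension, collected by
torch.prim.ListConstruct. The eager-mode equivalent (a sketch):

import torch

t = torch.randn(5, 3)  # stands in for !torch.vtensor<[?,3],f32>

# aten.size decomposes into a per-dimension size query...
dim0 = t.size(0)       # torch.aten.size.int %t, %int0
dim1 = t.size(1)       # torch.aten.size.int %t, %int1

# ...whose results are collected into a list (torch.prim.ListConstruct).
assert [dim0, dim1] == list(t.shape)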