mirror of https://github.com/llvm/torch-mlir
Add e2e support for aten._softmax_backward_data.
Decompose aten._softmax_backward_data into aten math ops. Also decompose `aten.size` to facilitate decomposing _softmax_backward_data.
parent 05c4dd8e39
commit 3bd9d2a4c7
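For reference, the decomposition this commit implements can be checked numerically in plain PyTorch. A minimal sketch (not part of the patch); `input_dtype=6` is the c10 ScalarType code for `torch.float32`, matching the e2e test below:

import torch

g = torch.randn(3, 2, 4)                        # grad_output
y = torch.softmax(torch.randn(3, 2, 4), dim=1)  # softmax output

ref = torch.ops.aten._softmax_backward_data(g, y, dim=1, input_dtype=6)
# The identity the decomposition relies on:
#   grad_input = g*y - y * sum(g*y, dim, keepdim=True)
dec = g * y - y * (g * y).sum(dim=1, keepdim=True)
assert torch.allclose(ref, dec, atol=1e-6)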
@@ -0,0 +1,35 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export

# ==============================================================================

class SoftmaxBackwardModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, grad_output, output):
        return torch.ops.aten._softmax_backward_data(grad_output,
                                                     output,
                                                     dim=1,
                                                     input_dtype=6)


@register_test_case(module_factory=lambda: SoftmaxBackwardModule())
def SoftmaxBackwardModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
@@ -224,7 +224,7 @@ def TransposeIntModule_basic(module, tu: TestUtils):

class PermuteModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
@@ -258,7 +258,7 @@ def TransposeIntNegDimsModule_basic(module, tu: TestUtils):

class PermuteNegativeIndexModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
@@ -374,7 +374,6 @@ def EmbeddingModule_basic(module, tu: TestUtils):

class SoftmaxIntModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(0)
        self.softmax = torch.nn.Softmax(2)

    @export
@@ -429,6 +428,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4).double())


class BroadcastToModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -509,7 +509,7 @@ class ContiguousModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ContiguousModule())
def ContiguousModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 1))


class TensorToInt(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -522,6 +522,7 @@ class TensorToInt(torch.nn.Module):
    def forward(self, x):
        return int(x)


@register_test_case(module_factory=lambda: TensorToInt())
def TensorToInt_basic(module, tu: TestUtils):
    module.forward(torch.randint(10, []))
@@ -543,6 +544,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
def LogSoftmaxIntModule_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4).double())


class NumToTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
@@ -35,6 +35,7 @@ from . import quantized_models
from . import elementwise
from . import type_promotion
from . import type_conversion
from . import backprop
from . import reduction
from . import argmax
from . import matmul
@@ -2785,3 +2785,20 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
  let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
}

def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
    AllowsTypeRefinement,
    HasValueSemantics
  ]> {
  let summary = "Generated op for `aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$grad_output,
    AnyTorchTensorType:$output,
    Torch_IntType:$dim,
    Torch_IntType:$input_dtype
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
}
@@ -33,6 +33,68 @@ static int getTensorRank(Value tensor) {
  return tensorRank;
}

static Value createAtenSum(PatternRewriter &rewriter, Location loc,
                           Operation *op, Value input, Value dim,
                           bool keepDim) {
  BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
  Value dimList = rewriter.create<PrimListConstructOp>(
      loc, Torch::ListType::get(dim.getType()), dim);
  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
  Value dtype = rewriter.create<ConstantNoneOp>(loc);
  SmallVector<int64_t> sizes;
  int64_t dimInt;
  if (tensorType.hasSizes()) {
    ArrayRef<int64_t> inputShape = tensorType.getSizes();
    int64_t inputRank = inputShape.size();
    if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
      dimInt = toPositiveDim(dimInt, inputRank);
      if (!isValidDim(dimInt, inputRank)) {
        (void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
        return nullptr;
      }
      sizes.append(inputShape.begin(), inputShape.end());
      sizes[dimInt] = 1;
    } else {
      sizes.resize(inputRank, kUnknownSize);
    }
  }

  Type resultType = tensorType.getWithSizesAndDtype(
      sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
                        : llvm::makeArrayRef(sizes),
      tensorType.getDtype());
  Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, input,
                                                   dimList, keepDimCst, dtype);
  return sum;
}
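Below is a minimal Python sketch (illustration only, not torch-mlir code) of the result-shape refinement `createAtenSum` computes for `sum(input, dim, keepdim=True)`: a statically known `dim` collapses to size 1, while a non-constant `dim` leaves every size unknown (`-1` stands in for `kUnknownSize`).

def refined_sum_shape(input_shape, dim):
    rank = len(input_shape)
    if dim is None:                       # dim is not a compile-time constant
        return [-1] * rank
    dim = dim + rank if dim < 0 else dim  # toPositiveDim
    if not 0 <= dim < rank:               # isValidDim
        raise ValueError("dim is not a valid dim")
    out = list(input_shape)
    out[dim] = 1                          # keepdim=True preserves the rank
    return out

assert refined_sum_shape([3, 2, 4], 1) == [3, 1, 4]
assert refined_sum_shape([3, 2, 4], -2) == [3, 1, 4]
assert refined_sum_shape([3, 2, 4], None) == [-1, -1, -1]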
namespace {
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSizeOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    MLIRContext *context = op.getContext();
    int64_t rank = getTensorRank(self);
    if (rank < 0)
      return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
    SmallVector<Value> sizes;
    for (int i = 0; i < rank; i++) {
      Value dim = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(i));
      sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
    }

    Value sizeList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(context)), sizes);
    rewriter.replaceOp(op, sizeList);
    return success();
  }
};
} // namespace
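The rewrite above is the Torch-dialect analogue of a simple PyTorch identity: a tensor's size list is just its per-dimension sizes collected in order (illustration only).

import torch

x = torch.randn(5, 3)
# aten.size decomposes into one aten.size.int per dimension, gathered
# into a prim.ListConstruct.
assert list(x.size()) == [x.size(i) for i in range(x.dim())]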
// Decompose softmax into: exp(x) / sum(exp(x))
namespace {
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
@@ -50,35 +112,13 @@ public:
    BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
      return rewriter.notifyMatchFailure(op, "Only support floating type");

    // exp(x)
    Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);

    // sum(exp(x))
    Value dimList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(dim.getType()), dim);
    Value keepDim = rewriter.create<ConstantBoolOp>(loc, true);
    Value dtype = rewriter.create<ConstantNoneOp>(loc);
    SmallVector<int64_t> sizes;
    int64_t dimInt;
    if (tensorType.hasSizes()) {
      ArrayRef<int64_t> inputShape = tensorType.getSizes();
      int64_t inputRank = inputShape.size();
      if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
        dimInt = toPositiveDim(dimInt, inputRank);
        if (!isValidDim(dimInt, inputRank))
          return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
        sizes.append(inputShape.begin(), inputShape.end());
        sizes[dimInt] = 1;
      } else {
        sizes.resize(inputRank, kUnknownSize);
      }
    }
    Type resultType = tensorType.getWithSizesAndDtype(
        sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
                          : llvm::makeArrayRef(sizes),
        tensorType.getDtype());
    Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, exp,
                                                     dimList, keepDim, dtype);
    Value sum = createAtenSum(rewriter, loc, op, exp, dim, /*keepDim=*/true);
    if (!sum)
      return failure();
    // exp(x) / sum(exp(x))
    Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -88,6 +128,56 @@ public:
};
} // namespace

// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
//    newGrad = gradOutput * output
//    result = newGrad - output * sum(newGrad, dim)
//
// Refer to
// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
namespace {
class DecomposeAten_SoftmaxBackwardDataOp
    : public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    MLIRContext *context = op.getContext();
    Value gradOutput = op.grad_output();
    Value output = op.output();
    Value dim = op.dim();

    BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
      return rewriter.notifyMatchFailure(op, "Only support floating type");

    Value newGrad =
        rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
    // temp = output * sum(newGrad, dim)
    Value sum =
        createAtenSum(rewriter, loc, op, newGrad, dim, /*keepDim=*/true);
    if (!sum)
      return failure();
    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
    Value broadcastSize =
        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
    Value sumBroadcast =
        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
    Value temp =
        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, sumBroadcast);

    // newGrad - temp
    Value alpha =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
    Value sub =
        rewriter.create<AtenSubTensorOp>(loc, tensorType, newGrad, temp, alpha);

    rewriter.replaceOp(op, sub);
    return success();
  }
};
} // namespace
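The pattern above can be mirrored step by step in PyTorch (a sketch, not the pattern itself). It shows why the keepdim sum is explicitly broadcast back to the shape of `output` (the pattern's `AtenBroadcastToOp` on `output.size()`) before the final subtraction, and that the result agrees with autograd:

import torch

grad_output = torch.randn(3, 2, 4)
x = torch.randn(3, 2, 4, requires_grad=True)
output = torch.softmax(x, dim=1)

new_grad = grad_output * output         # gradOutput * output
s = new_grad.sum(dim=1, keepdim=True)   # shape (3, 1, 4)
s = s.broadcast_to(output.shape)        # back to (3, 2, 4)
result = new_grad - output * s          # newGrad - temp, with alpha = 1

output.backward(grad_output)            # autograd's softmax backward
assert torch.allclose(result.detach(), x.grad, atol=1e-6)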
// Decompose aten.log_softmax op into: log(softmax(x))
namespace {
class DecomposeAtenLogSoftmaxIntOp
@@ -177,6 +267,10 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenLogSoftmaxIntOp>();
    patterns.add<DecomposeAtenExpandOp>(context);
    target.addIllegalOp<AtenExpandOp>();
    patterns.add<DecomposeAtenSizeOp>(context);
    target.addIllegalOp<AtenSizeOp>();
    patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
    target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
    patterns.add<DecomposeAtenMatmulOp>(context);
    target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
      int lhsRank = getTensorRank(op.self());
@@ -230,7 +230,7 @@ public:
        AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
        AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
        AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
        AtenFloorOp, AtenLog2Op>(op)) {
        AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp>(op)) {
      return getLatticeElement(op->getResult(0)).join(*operands[0]);
    }
@@ -621,6 +621,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::div : (Scalar, Scalar) -> (float)")
    emit("aten::eq.device : (Device, Device) -> (bool)")

    # backprop ops
    emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")


def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
    td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
@@ -107,3 +107,17 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
  %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<*,f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
  return %ret : !torch.tensor<*,f32>
}

// -----
// CHECK-LABEL:   func @torch.aten.size(
// CHECK-SAME:      %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
// CHECK:           %[[CST0:.*]] = torch.constant.int 0
// CHECK:           %[[DIM0:.*]] = torch.aten.size.int %[[T]], %[[CST0]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
// CHECK:           %[[CST1:.*]] = torch.constant.int 1
// CHECK:           %[[DIM1:.*]] = torch.aten.size.int %[[T]], %[[CST1]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
// CHECK:           %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
// CHECK:           return %[[SIZE]] : !torch.list<!torch.int>
func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
  %0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<!torch.int>
  return %0 : !torch.list<!torch.int>
}