mirror of https://github.com/llvm/torch-mlir
Add e2e support for aten._softmax_backward_data.
Decompose aten._softmax_backward_data into aten math ops. Also decompose `aten.size` to facilitate decomposing _softmax_backward_data.

pull/407/head snapshot-20211109.73
parent 05c4dd8e39
commit 3bd9d2a4c7
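For reference, the identity behind the decomposition can be checked in eager PyTorch. A minimal standalone sketch (not part of the commit): for y = softmax(x, dim) and upstream gradient g, the input gradient is g*y - y*sum(g*y, dim, keepdim=True).

    import torch

    x = torch.randn(3, 2, 4, requires_grad=True)
    y = torch.softmax(x, dim=1)
    g = torch.randn(3, 2, 4)

    # Decomposed form: newGrad = g * y; result = newGrad - y * sum(newGrad, dim)
    new_grad = g * y
    decomposed = new_grad - y * new_grad.sum(dim=1, keepdim=True)

    # Reference: let autograd differentiate the softmax directly.
    expected, = torch.autograd.grad(y, x, grad_outputs=g)
    print(torch.allclose(decomposed, expected, atol=1e-6))  # True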
@@ -0,0 +1,35 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+# Also available under a BSD-style license. See LICENSE.
+
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+# ==============================================================================
+
+
+class SoftmaxBackwardModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, grad_output, output):
+        return torch.ops.aten._softmax_backward_data(grad_output,
+                                                     output,
+                                                     dim=1,
+                                                     input_dtype=6)
+
+
+@register_test_case(module_factory=lambda: SoftmaxBackwardModule())
+def SoftmaxBackwardModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 2, 4), torch.randn(3, 2, 4))
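A note on the magic number in the test: input_dtype=6 is the raw value of torch's ScalarType enum entry for float32. A standalone eager-mode sketch of what the test exercises (not part of the commit; recent PyTorch also accepts a torch.dtype in place of the raw enum value):

    import torch

    g = torch.randn(3, 2, 4)                        # grad_output
    y = torch.softmax(torch.randn(3, 2, 4), dim=1)  # output of a softmax
    out = torch.ops.aten._softmax_backward_data(g, y, 1, torch.float32)
    print(out.shape)  # torch.Size([3, 2, 4])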
@@ -374,7 +374,6 @@ def EmbeddingModule_basic(module, tu: TestUtils):
 class SoftmaxIntModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        torch.manual_seed(0)
         self.softmax = torch.nn.Softmax(2)
 
     @export
@@ -429,6 +428,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
 def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).double())
 
+
 class BroadcastToModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -522,6 +522,7 @@ class TensorToInt(torch.nn.Module):
     def forward(self, x):
         return int(x)
 
+
 @register_test_case(module_factory=lambda: TensorToInt())
 def TensorToInt_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, []))
@@ -543,6 +544,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
 def LogSoftmaxIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4).double())
 
+
 class NumToTensorModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -35,6 +35,7 @@ from . import quantized_models
 from . import elementwise
 from . import type_promotion
 from . import type_conversion
+from . import backprop
 from . import reduction
 from . import argmax
 from . import matmul
@@ -2785,3 +2785,20 @@ def Torch_AtenEqDeviceOp : Torch_Op<"aten.eq.device", [
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
 }
 
+def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$grad_output,
+    AnyTorchTensorType:$output,
+    Torch_IntType:$dim,
+    Torch_IntType:$input_dtype
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$grad_output `,` $output `,` $dim `,` $input_dtype attr-dict `:` type($grad_output) `,` type($output) `,` type($dim) `,` type($input_dtype) `->` type($result)";
+}
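The TableGen definition above is generated from the op's registered schema (see the torch_ods_gen.py hunk below). In recent PyTorch the schema can be inspected directly; a hedged one-liner, assuming the OpOverload API is available:

    import torch

    # Prints the registered schema; the trailing input_dtype argument is the
    # ScalarType that the commit's registry records as a plain int.
    print(torch.ops.aten._softmax_backward_data.default._schema)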
@@ -33,6 +33,68 @@ static int getTensorRank(Value tensor) {
   return tensorRank;
 }
 
+static Value createAtenSum(PatternRewriter &rewriter, Location loc,
+                           Operation *op, Value input, Value dim,
+                           bool keepDim) {
+  BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
+  Value dimList = rewriter.create<PrimListConstructOp>(
+      loc, Torch::ListType::get(dim.getType()), dim);
+  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
+  Value dtype = rewriter.create<ConstantNoneOp>(loc);
+  SmallVector<int64_t> sizes;
+  int64_t dimInt;
+  if (tensorType.hasSizes()) {
+    ArrayRef<int64_t> inputShape = tensorType.getSizes();
+    int64_t inputRank = inputShape.size();
+    if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
+      dimInt = toPositiveDim(dimInt, inputRank);
+      if (!isValidDim(dimInt, inputRank)) {
+        (void)rewriter.notifyMatchFailure(op, "dim is not a valid dim");
+        return nullptr;
+      }
+      sizes.append(inputShape.begin(), inputShape.end());
+      sizes[dimInt] = 1;
+    } else {
+      sizes.resize(inputRank, kUnknownSize);
+    }
+  }
+
+  Type resultType = tensorType.getWithSizesAndDtype(
+      sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
+                        : llvm::makeArrayRef(sizes),
+      tensorType.getDtype());
+  Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, input,
+                                                   dimList, keepDimCst, dtype);
+  return sum;
+}
+
+namespace {
+class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSizeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.self();
+    MLIRContext *context = op.getContext();
+    int64_t rank = getTensorRank(self);
+    if (rank < 0)
+      return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor");
+    SmallVector<Value> sizes;
+    for (int i = 0; i < rank; i++) {
+      Value dim = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(i));
+      sizes.push_back(rewriter.create<AtenSizeIntOp>(loc, self, dim));
+    }
+
+    Value sizeList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(context)), sizes);
+    rewriter.replaceOp(op, sizeList);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose softmax into: exp(x) / sum(exp(x))
 namespace {
 class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
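Conceptually, DecomposeAtenSizeOp replaces a single aten.size op with one aten.size.int per dimension gathered into a prim.ListConstruct. A rough Python analogue of the rewrite's effect (a sketch; decomposed_size is a hypothetical helper, not part of the commit):

    import torch

    def decomposed_size(t: torch.Tensor) -> list:
        # t.size(d) corresponds to aten.size.int; the list comprehension plays
        # the role of prim.ListConstruct in the emitted IR.
        return [t.size(d) for d in range(t.dim())]

    t = torch.randn(5, 3)
    assert decomposed_size(t) == list(t.shape)  # [5, 3]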
@@ -50,35 +112,13 @@ public:
     BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
     if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(op, "Only support floating type");
 
     // exp(x)
     Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
 
     // sum(exp(x))
-    Value dimList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(dim.getType()), dim);
-    Value keepDim = rewriter.create<ConstantBoolOp>(loc, true);
-    Value dtype = rewriter.create<ConstantNoneOp>(loc);
-    SmallVector<int64_t> sizes;
-    int64_t dimInt;
-    if (tensorType.hasSizes()) {
-      ArrayRef<int64_t> inputShape = tensorType.getSizes();
-      int64_t inputRank = inputShape.size();
-      if (matchPattern(dim, m_TorchConstantInt(&dimInt))) {
-        dimInt = toPositiveDim(dimInt, inputRank);
-        if (!isValidDim(dimInt, inputRank))
-          return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
-        sizes.append(inputShape.begin(), inputShape.end());
-        sizes[dimInt] = 1;
-      } else {
-        sizes.resize(inputRank, kUnknownSize);
-      }
-    }
-    Type resultType = tensorType.getWithSizesAndDtype(
-        sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
-                          : llvm::makeArrayRef(sizes),
-        tensorType.getDtype());
-    Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, exp,
-                                                     dimList, keepDim, dtype);
+    Value sum = createAtenSum(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+    if (!sum)
+      return failure();
     // exp(x) / sum(exp(x))
     Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -88,6 +128,56 @@ public:
 };
 } // namespace
 
+// Aten_SoftmaxBackwardDataOp(gradOutput, output, dim) =>
+//    newGrad = gradOutput * output
+//    result = newGrad - output * sum(newGrad, dim)
+//
+// Refer to
+// https://github.com/pytorch/pytorch/blob/15fecc4c830a3907fde4b44c9962dc4144da50a4/torch/csrc/jit/codegen/cuda/ops/normalization.cpp#L31
+namespace {
+class DecomposeAten_SoftmaxBackwardDataOp
+    : public OpRewritePattern<Aten_SoftmaxBackwardDataOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_SoftmaxBackwardDataOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op.getContext();
+    Value gradOutput = op.grad_output();
+    Value output = op.output();
+    Value dim = op.dim();
+
+    BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
+    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(op, "Only support floating type");
+
+    Value newGrad =
+        rewriter.create<AtenMulTensorOp>(loc, tensorType, gradOutput, output);
+    // temp = output * sum(newGrad, dim)
+    Value sum =
+        createAtenSum(rewriter, loc, op, newGrad, dim, /*keepDim=*/true);
+    if (!sum)
+      return failure();
+    auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(context));
+    Value broadcastSize =
+        rewriter.create<AtenSizeOp>(loc, broadcastSizeType, output);
+    Value sumBroadcast =
+        rewriter.create<AtenBroadcastToOp>(loc, tensorType, sum, broadcastSize);
+    Value temp =
+        rewriter.create<AtenMulTensorOp>(loc, tensorType, output, sumBroadcast);
+
+    // newGrad - temp
+    Value alpha =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
+    Value sub =
+        rewriter.create<AtenSubTensorOp>(loc, tensorType, newGrad, temp, alpha);
+
+    rewriter.replaceOp(op, sub);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.log_softmax op into: log(softmax(x))
 namespace {
 class DecomposeAtenLogSoftmaxIntOp
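An op-by-op Python analogue of the sequence this pattern emits (a sketch, not the pass itself; the aten-op correspondences are noted in the comments):

    import torch

    g = torch.randn(3, 2, 4)                        # grad_output
    y = torch.softmax(torch.randn(3, 2, 4), dim=1)  # output

    new_grad = torch.mul(g, y)                      # AtenMulTensorOp
    s = torch.sum(new_grad, dim=1, keepdim=True)    # createAtenSum -> AtenSumDimIntListOp
    s = s.broadcast_to(y.size())                    # AtenSizeOp + AtenBroadcastToOp
    temp = torch.mul(y, s)                          # AtenMulTensorOp
    result = torch.sub(new_grad, temp, alpha=1)     # AtenSubTensorOp with alpha = 1.0

    expected = torch.ops.aten._softmax_backward_data(g, y, 1, torch.float32)
    print(torch.allclose(result, expected, atol=1e-6))  # True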
@@ -177,6 +267,10 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenLogSoftmaxIntOp>();
     patterns.add<DecomposeAtenExpandOp>(context);
     target.addIllegalOp<AtenExpandOp>();
+    patterns.add<DecomposeAtenSizeOp>(context);
+    target.addIllegalOp<AtenSizeOp>();
+    patterns.add<DecomposeAten_SoftmaxBackwardDataOp>(context);
+    target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
     patterns.add<DecomposeAtenMatmulOp>(context);
     target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
       int lhsRank = getTensorRank(op.self());
@@ -230,7 +230,7 @@ public:
             AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
             AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
             AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenSqrtOp,
-            AtenFloorOp, AtenLog2Op>(op)) {
+            AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
@@ -621,6 +621,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::div : (Scalar, Scalar) -> (float)")
     emit("aten::eq.device : (Device, Device) -> (bool)")
 
+    # backprop ops
+    emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
+
 
 def emit_quantized_ops(torch_ir_dir: str, registry: Registry):
     td_file = os.path.join(torch_ir_dir, "GeneratedQuantizedOps.td")
@@ -107,3 +107,17 @@ func @torch.aten.softmax.int$unknown_shape(%t: !torch.tensor<*,f32>) -> !torch.t
   %ret = torch.aten.softmax.int %t, %dim, %none : !torch.tensor<*,f32>, !torch.int, !torch.none -> !torch.tensor<*,f32>
   return %ret : !torch.tensor<*,f32>
 }
+
+// -----
+// CHECK-LABEL:  func @torch.aten.size(
+// CHECK-SAME:     %[[T:.*]]: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
+// CHECK:          %[[CST0:.*]] = torch.constant.int 0
+// CHECK:          %[[DIM0:.*]] = torch.aten.size.int %[[T]], %[[CST0]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
+// CHECK:          %[[CST1:.*]] = torch.constant.int 1
+// CHECK:          %[[DIM1:.*]] = torch.aten.size.int %[[T]], %[[CST1]] : !torch.vtensor<[?,3],f32>, !torch.int -> !torch.int
+// CHECK:          %[[SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
+// CHECK:          return %[[SIZE]] : !torch.list<!torch.int>
+func @torch.aten.size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<!torch.int> {
+  %0 = torch.aten.size %arg0 : !torch.vtensor<[?,3],f32> -> !torch.list<!torch.int>
+  return %0 : !torch.list<!torch.int>
+}