Modify softmax decomposition to be more numerically stable.

The softmax decomposition is modified according to
https://github.com/pytorch/functorch/blob/main/functorch/_src/decompositions.py
to account for numerical stability. Also, the aten.argmax lowering was modified
to handle negative dimensions.
Prashant Kumar 2022-02-01 01:26:32 +05:30
parent 0079901039
commit 68acc8696e
4 changed files with 116 additions and 24 deletions
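As a quick illustration of why the change matters, here is a minimal eager-mode PyTorch sketch of the stable decomposition the commit adopts (following the referenced functorch decomposition); the helper name and sample values are illustrative only and are not part of this change:

import torch

def stable_softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Subtract the per-slice maximum before exponentiating, so exp() never
    # sees large positive inputs and cannot overflow to inf.
    x_max, _ = torch.max(x, dim=dim, keepdim=True)
    unnorm = torch.exp(x - x_max)
    return unnorm / torch.sum(unnorm, dim=dim, keepdim=True)

x = torch.tensor([[1000.0, 1001.0, 1002.0]])
print(torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True))  # nan: exp overflows
print(stable_softmax(x, dim=-1))  # tensor([[0.0900, 0.2447, 0.6652]])
print(torch.softmax(x, dim=-1))   # matches the stable form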


@@ -216,3 +216,21 @@ class ReduceMaxAllDims(torch.nn.Module):
@register_test_case(module_factory=lambda: ReduceMaxAllDims())
def ReduceMaxAllDims_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
# ==============================================================================
class ReduceMaxNegativeDim(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, a):
return torch.ops.aten.max(a, -1, keepdim=True)
@register_test_case(module_factory=lambda: ReduceMaxNegativeDim())
def ReduceMaxNegativeDim_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))
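As a sanity check on the new test case, the same call behaves as follows in eager PyTorch (an assumed illustration, not part of the test suite):

import torch

a = torch.rand(3, 4, 5)
values, indices = torch.ops.aten.max(a, -1, keepdim=True)
# dim = -1 refers to the last dimension; keepdim=True keeps it with size 1.
assert values.shape == (3, 4, 1) and indices.shape == (3, 4, 1)
assert torch.equal(values, a.max(dim=-1, keepdim=True).values)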


@@ -2206,6 +2206,9 @@ public:
if (!matchPattern(maxDimOp.dim(), m_TorchConstantInt(&dim)))
return rewriter.notifyMatchFailure(
maxDimOp, "aten.max_dim to linalg.* requires int value for Dim");
dim = toPositiveDim(dim, inputType.getRank());
if (!isValidDim(dim, inputType.getRank()))
return rewriter.notifyMatchFailure(maxDimOp, "dim is not a valid dim");
Type inElementType = inputType.getElementType();
if (!inElementType.isa<mlir::FloatType>()) {
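The dimension normalization added here follows the usual PyTorch convention; a small Python sketch of what `toPositiveDim` and `isValidDim` check (illustrative, not the torch-mlir helpers themselves):

def to_positive_dim(dim: int, rank: int) -> int:
    # Negative dims count from the end, so -1 denotes the last dimension.
    return dim + rank if dim < 0 else dim

def is_valid_dim(dim: int, rank: int) -> bool:
    return 0 <= dim < rank

assert to_positive_dim(-1, 3) == 2
assert is_valid_dim(to_positive_dim(-1, 3), 3)
assert not is_valid_dim(to_positive_dim(-4, 3), 3)  # out of range, match fails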


@@ -33,14 +33,12 @@ static int getTensorRank(Value tensor) {
return tensorRank;
}
static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim,
bool keepDim) {
// Helper function to compute the return type of the reduction function.
// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
// the input tensor.
static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
Value input, Value dim, bool keepDim) {
BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(dim.getType()), dim);
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
Value dtype = rewriter.create<ConstantNoneOp>(loc);
SmallVector<int64_t> sizes;
int64_t dimInt;
if (tensorType.hasSizes()) {
@@ -53,9 +51,15 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
return nullptr;
}
sizes.append(inputShape.begin(), inputShape.end());
sizes[dimInt] = 1;
// The dimension to be reduced is set to 1 when `keepDim` is true else it
// is removed.
if (keepDim)
sizes[dimInt] = 1;
else
sizes.erase(sizes.begin() + dimInt);
} else {
sizes.resize(inputRank, kUnknownSize);
unsigned reducedRank = keepDim ? inputRank : inputRank - 1;
sizes.resize(reducedRank, kUnknownSize);
}
}
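In eager terms, the keepDim handling above corresponds to the following shape behavior (illustrative only):

import torch

x = torch.rand(2, 3, 4)
# keepDim = True: the reduced dimension stays, with size 1.
assert x.sum(dim=1, keepdim=True).shape == (2, 1, 4)
# keepDim = False: the reduced dimension is erased, lowering the rank by one.
assert x.sum(dim=1, keepdim=False).shape == (2, 4)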
@@ -63,9 +67,44 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(sizes),
tensorType.getDtype());
Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, input,
dimList, keepDimCst, dtype);
return sum;
return resultType;
}
// Reduction function to calculate sum along given `dim`.
static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim,
bool keepDim) {
Value dimList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(dim.getType()), dim);
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
Value dtype = rewriter.create<ConstantNoneOp>(loc);
Type resultType = computeReductionType(rewriter, op, input, dim, keepDim);
if (!resultType)
return nullptr;
return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
keepDimCst, dtype);
}
// Reduction function to calculate max along given `dim`.
static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
Operation *op, Value input, Value dim,
bool keepDim) {
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
BaseTensorType valueType =
computeReductionType(rewriter, op, input, dim, keepDim)
.cast<BaseTensorType>();
if (!valueType)
return nullptr;
BaseTensorType indexType =
valueType
.getWithSizesAndDtype(
!valueType.hasSizes() ? Optional<ArrayRef<int64_t>>()
: llvm::makeArrayRef(valueType.getSizes()),
IntegerType::get(op->getContext(), 64, IntegerType::Signed))
.cast<BaseTensorType>();
return rewriter
.create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
.values();
}
// Helper for creating `aten::sub_tensor_op`.
@@ -167,22 +206,29 @@ public:
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)).
// To avoid overflow we use the following decomposition rule:
// x_max = max(input, dim, keepdim = True)
// unnorm = aten.exp(input - x_max)
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
template <typename OpTy>
static Value getSoftmaxResult(OpTy op, Type resultType,
PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value dim = op.dim();
Value self = op.self();
// exp(x)
Value exp = rewriter.create<AtenExpOp>(loc, resultType, self);
// sum(exp(x))
Value sum =
createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
Value xMax =
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
if (!xMax)
return nullptr;
Value unNormalized = createTensorSub(rewriter, loc, resultType, self, xMax);
Value unNormalizedExp =
rewriter.create<AtenExpOp>(loc, resultType, unNormalized);
Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim,
/*keepDim=*/true);
if (!sum)
return nullptr;
// exp(x) / sum(exp(x))
return rewriter.create<AtenDivTensorOp>(loc, resultType, exp, sum);
return rewriter.create<AtenDivTensorOp>(loc, resultType, unNormalizedExp,
sum);
}
// Decompose softmax into: exp(x) / sum(exp(x))


@@ -30,7 +30,13 @@ func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[T]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[KEEP_DIM0:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[KEEP_DIM0]] :
// CHECK-SAME: !torch.tensor<[2,3],f32>, !torch.int, !torch.bool -> !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.float -> !torch.tensor<[2,3],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
// CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
@@ -45,12 +51,19 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
return %ret : !torch.tensor<[2,3],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.softmax.int$cst_dim(
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[T]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[TRU:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.bool ->
// CHECK-SAME: !torch.tensor<[2,1],f32>, !torch.tensor<[2,1],si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
// CHECK-SAME: !torch.tensor<[2,1],f32>, !torch.float -> !torch.tensor<[2,3],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
// CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
@@ -71,7 +84,13 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[T]] : !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
// CHECK: %[[TRU:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool ->
// CHECK-SAME: !torch.tensor<[?,1],f32>, !torch.tensor<[?,1],si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[?,?],f32>,
// CHECK-SAME: !torch.tensor<[?,1],f32>, !torch.float -> !torch.tensor<[?,?],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
// CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
@@ -92,7 +111,13 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
// CHECK-SAME: %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
// CHECK: %[[DTYPE:.*]] = torch.constant.none
// CHECK: %[[DIM:.*]] = torch.constant.int 1
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[T]] : !torch.tensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[TRU:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<*,f32>, !torch.int, !torch.bool
// CHECK-SAME: -> !torch.tensor<*,f32>, !torch.tensor<*,si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<*,f32>, !torch.tensor<*,f32>,
// CHECK-SAME: !torch.float -> !torch.tensor<*,f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<*,f32> -> !torch.tensor<*,f32>
// CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
// CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none