mirror of https://github.com/llvm/torch-mlir
Modify softmax decomposition to be more numerically stable.
The softmax decomposition is modified according to https://github.com/pytorch/functorch/blob/main/functorch/_src/decompositions.py to account for numerical stability. Also modified the aten.argmax lowering to handle negative dimensions.
pull/558/head snapshot-20220203.246
parent 0079901039
commit 68acc8696e
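In eager-mode PyTorch terms, the decomposition this commit switches to looks like the sketch below (illustrative only, not the torch-mlir implementation; the helper name `stable_softmax` is made up):

    import torch

    def stable_softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
        # Subtracting the per-dim max keeps every exponent <= 0, so exp() cannot overflow.
        x_max, _ = torch.max(x, dim=dim, keepdim=True)
        unnorm = torch.exp(x - x_max)
        return unnorm / torch.sum(unnorm, dim, keepdim=True)

    # Matches torch.softmax even for inputs where the naive exp(x)/sum(exp(x)) overflows.
    t = torch.tensor([[1000.0, 1001.0, 1002.0]])
    assert torch.allclose(stable_softmax(t, -1), torch.softmax(t, dim=-1))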
@@ -216,3 +216,21 @@ class ReduceMaxAllDims(torch.nn.Module):
 @register_test_case(module_factory=lambda: ReduceMaxAllDims())
 def ReduceMaxAllDims_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5, low=-10, high=-5))
+
+# ==============================================================================
+
+class ReduceMaxNegativeDim(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.ops.aten.max(a, -1, keepdim=True)
+
+@register_test_case(module_factory=lambda: ReduceMaxNegativeDim())
+def ReduceMaxNegativeDim_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
@@ -2206,6 +2206,9 @@ public:
     if (!matchPattern(maxDimOp.dim(), m_TorchConstantInt(&dim)))
       return rewriter.notifyMatchFailure(
           maxDimOp, "aten.max_dim to linalg.* requires int value for Dim");
+    dim = toPositiveDim(dim, inputType.getRank());
+    if (!isValidDim(dim, inputType.getRank()))
+      return rewriter.notifyMatchFailure(maxDimOp, "dim is not a valid dim");

     Type inElementType = inputType.getElementType();
     if (!inElementType.isa<mlir::FloatType>()) {
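The `toPositiveDim`/`isValidDim` pair added above wraps a possibly negative dimension index against the tensor rank before the lowering uses it. The same convention in Python (the helper names here are made up for illustration):

    def to_positive_dim(dim: int, rank: int) -> int:
        # PyTorch-style wrap-around: -1 means the last dimension, -rank the first.
        return dim + rank if dim < 0 else dim

    def is_valid_dim(dim: int, rank: int) -> bool:
        return 0 <= dim < rank

    assert to_positive_dim(-1, 3) == 2   # aten.max(a, -1) reduces dim 2 of a rank-3 tensor
    assert is_valid_dim(to_positive_dim(-1, 3), 3)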
@@ -33,14 +33,12 @@ static int getTensorRank(Value tensor) {
   return tensorRank;
 }

-static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
-                                     Operation *op, Value input, Value dim,
-                                     bool keepDim) {
+// Helper function to compute the return type of the reduction function.
+// `dim` specifies the dimension to reduce and `keepDim` preserves the rank of
+// the input tensor.
+static Type computeReductionType(PatternRewriter &rewriter, Operation *op,
+                                 Value input, Value dim, bool keepDim) {
   BaseTensorType tensorType = input.getType().cast<BaseTensorType>();
-  Value dimList = rewriter.create<PrimListConstructOp>(
-      loc, Torch::ListType::get(dim.getType()), dim);
-  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
-  Value dtype = rewriter.create<ConstantNoneOp>(loc);
   SmallVector<int64_t> sizes;
   int64_t dimInt;
   if (tensorType.hasSizes()) {
@@ -53,9 +51,15 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
       return nullptr;
     }
     sizes.append(inputShape.begin(), inputShape.end());
-    sizes[dimInt] = 1;
+    // The dimension to be reduced is set to 1 when `keepDim` is true else it
+    // is removed.
+    if (keepDim)
+      sizes[dimInt] = 1;
+    else
+      sizes.erase(sizes.begin() + dimInt - 1);
   } else {
-    sizes.resize(inputRank, kUnknownSize);
+    unsigned reducedRank = keepDim ? inputRank : inputRank - 1;
+    sizes.resize(reducedRank, kUnknownSize);
   }
 }

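The shape rule that `computeReductionType` encodes for statically shaped tensors, sketched in Python (hypothetical helper name, a sketch of the intended keepDim semantics rather than the C++ code):

    def reduced_shape(shape, dim, keep_dim):
        sizes = list(shape)
        if keep_dim:
            sizes[dim] = 1      # keep the rank; the reduced dim becomes size 1
        else:
            sizes.pop(dim)      # drop the reduced dim entirely
        return sizes

    assert reduced_shape([2, 3], dim=1, keep_dim=True) == [2, 1]
    assert reduced_shape([2, 3], dim=1, keep_dim=False) == [2]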
@@ -63,9 +67,44 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
       sizes.size() == 0 ? Optional<ArrayRef<int64_t>>()
                         : llvm::makeArrayRef(sizes),
       tensorType.getDtype());
-  Value sum = rewriter.create<AtenSumDimIntListOp>(loc, resultType, input,
-                                                   dimList, keepDimCst, dtype);
-  return sum;
+  return resultType;
 }

+// Reduction function to calculate sum along given `dim`.
+static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
+                                     Operation *op, Value input, Value dim,
+                                     bool keepDim) {
+  Value dimList = rewriter.create<PrimListConstructOp>(
+      loc, Torch::ListType::get(dim.getType()), dim);
+  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
+  Value dtype = rewriter.create<ConstantNoneOp>(loc);
+  Type resultType = computeReductionType(rewriter, op, input, dim, keepDim);
+  if (!resultType)
+    return nullptr;
+  return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
+                                              keepDimCst, dtype);
+}
+
+// Reduction function to calculate max along given `dim`.
+static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
+                                     Operation *op, Value input, Value dim,
+                                     bool keepDim) {
+  Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
+  BaseTensorType valueType =
+      computeReductionType(rewriter, op, input, dim, keepDim)
+          .cast<BaseTensorType>();
+  if (!valueType)
+    return nullptr;
+  BaseTensorType indexType =
+      valueType
+          .getWithSizesAndDtype(
+              !valueType.hasSizes() ? Optional<ArrayRef<int64_t>>()
+                                    : llvm::makeArrayRef(valueType.getSizes()),
+              IntegerType::get(op->getContext(), 64, IntegerType::Signed))
+          .cast<BaseTensorType>();
+  return rewriter
+      .create<AtenMaxDimOp>(loc, valueType, indexType, input, dim, keepDimCst)
+      .values();
+}
+
 // Helper for creating `aten::sub_tensor_op`.
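`createMaxAlongDimension` builds an `aten.max.dim` op, which produces both the max values and their int64 indices, and keeps only the values. The PyTorch-level equivalent of that behavior (a small sketch, not the compiler code):

    import torch

    x = torch.randn(3, 4, 5)
    values, indices = torch.max(x, dim=-1, keepdim=True)
    # values has the reduced shape [3, 4, 1]; indices is an int64 tensor of the same shape.
    assert values.shape == (3, 4, 1) and indices.dtype == torch.int64
    # The softmax decomposition only needs the values, mirroring `.values()` above.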
@@ -167,22 +206,29 @@ public:

 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
+// To avoid overflow we use the following decomposition rule:
+//     x_max = max(input, dim, keepdim = True)
+//     unnorm = aten.exp(input - x_max)
+//     softmax = unnorm / sum(unnorm, dim, keepdim = True)
 template <typename OpTy>
 static Value getSoftmaxResult(OpTy op, Type resultType,
                               PatternRewriter &rewriter) {
   Location loc = op.getLoc();
   Value dim = op.dim();
   Value self = op.self();

-  // exp(x)
-  Value exp = rewriter.create<AtenExpOp>(loc, resultType, self);
-  // sum(exp(x))
-  Value sum =
-      createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+  Value xMax =
+      createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
+  if (!xMax)
+    return nullptr;
+  Value unNormalized = createTensorSub(rewriter, loc, resultType, self, xMax);
+  Value unNormalizedExp =
+      rewriter.create<AtenExpOp>(loc, resultType, unNormalized);
+  Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim,
+                                      /*keepDim=*/true);
   if (!sum)
     return nullptr;
-  // exp(x) / sum(exp(x))
-  return rewriter.create<AtenDivTensorOp>(loc, resultType, exp, sum);
+  return rewriter.create<AtenDivTensorOp>(loc, resultType, unNormalizedExp,
+                                          sum);
 }

 // Decompose softmax into: exp(x) / sum(exp(x))
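A quick numeric illustration of why the max is subtracted first (plain PyTorch, not part of the patch): with large inputs the naive form overflows to inf and yields nan, while the shifted form stays finite because the largest exponent is exactly exp(0) = 1.

    import torch

    x = torch.tensor([100.0, 101.0, 102.0])

    naive = torch.exp(x) / torch.exp(x).sum()            # exp(100) overflows float32 -> nan
    x_max = x.max()
    stable = torch.exp(x - x_max) / torch.exp(x - x_max).sum()

    print(naive)   # tensor([nan, nan, nan])
    print(stable)  # tensor([0.0900, 0.2447, 0.6652]) -- matches torch.softmax(x, 0)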
@@ -30,7 +30,13 @@ func @matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vten
 // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
 // CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
 // CHECK: %[[DTYPE:.*]] = torch.constant.none
-// CHECK: %[[EXP:.*]] = torch.aten.exp %[[T]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
+// CHECK: %[[KEEP_DIM0:.*]] = torch.constant.bool true
+// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[KEEP_DIM0]] :
+// CHECK-SAME: !torch.tensor<[2,3],f32>, !torch.int, !torch.bool -> !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],si64>
+// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
+// CHECK-SAME: !torch.tensor<[?,?],f32>, !torch.float -> !torch.tensor<[2,3],f32>
+// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
 // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
 // CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
 // CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
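The decomposed IR checked above corresponds, op for op, to the following eager-mode sequence (an illustrative sketch; `alpha=1` mirrors the `torch.constant.float 1.000000e+00` operand of `aten.sub.Tensor`):

    import torch

    t = torch.randn(2, 3)
    dim = 1

    val, ind = torch.max(t, dim=dim, keepdim=True)      # aten.max.dim
    sub = torch.sub(t, val, alpha=1)                     # aten.sub.Tensor with alpha = 1.0
    exp = torch.exp(sub)                                 # aten.exp
    total = torch.sum(exp, dim=[dim], keepdim=True)      # aten.sum.dim_IntList
    result = torch.div(exp, total)                       # aten.div.Tensor

    assert torch.allclose(result, torch.softmax(t, dim))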
@@ -45,12 +51,19 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
   return %ret : !torch.tensor<[2,3],f32>
 }

+
 // ----
 // CHECK-LABEL: func @torch.aten.softmax.int$cst_dim(
 // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>) -> !torch.tensor<[2,3],f32> {
 // CHECK: %[[DTYPE:.*]] = torch.constant.none
 // CHECK: %[[DIM:.*]] = torch.constant.int 1
-// CHECK: %[[EXP:.*]] = torch.aten.exp %[[T]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
+// CHECK: %[[TRU:.*]] = torch.constant.bool true
+// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[2,3],f32>, !torch.int, !torch.bool ->
+// CHECK-SAME: !torch.tensor<[2,1],f32>, !torch.tensor<[2,1],si64>
+// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[2,3],f32>,
+// CHECK-SAME: !torch.tensor<[2,1],f32>, !torch.float -> !torch.tensor<[2,3],f32>
+// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
 // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
 // CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
 // CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
@@ -71,7 +84,13 @@ func @torch.aten.softmax.int$cst_dim(%t: !torch.tensor<[2,3],f32>) -> !torch.ten
 // CHECK-SAME: %[[T:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
 // CHECK: %[[DTYPE:.*]] = torch.constant.none
 // CHECK: %[[DIM:.*]] = torch.constant.int 1
-// CHECK: %[[EXP:.*]] = torch.aten.exp %[[T]] : !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
+// CHECK: %[[TRU:.*]] = torch.constant.bool true
+// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool ->
+// CHECK-SAME: !torch.tensor<[?,1],f32>, !torch.tensor<[?,1],si64>
+// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<[?,?],f32>,
+// CHECK-SAME: !torch.tensor<[?,1],f32>, !torch.float -> !torch.tensor<[?,?],f32>
+// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
 // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
 // CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
 // CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none
@@ -92,7 +111,13 @@ func @torch.aten.softmax.int$dyn_shape(%t: !torch.tensor<[?,?],f32>) -> !torch.t
 // CHECK-SAME: %[[T:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
 // CHECK: %[[DTYPE:.*]] = torch.constant.none
 // CHECK: %[[DIM:.*]] = torch.constant.int 1
-// CHECK: %[[EXP:.*]] = torch.aten.exp %[[T]] : !torch.tensor<*,f32> -> !torch.tensor<*,f32>
+// CHECK: %[[TRU:.*]] = torch.constant.bool true
+// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[T]], %[[DIM]], %[[TRU]] : !torch.tensor<*,f32>, !torch.int, !torch.bool
+// CHECK-SAME: -> !torch.tensor<*,f32>, !torch.tensor<*,si64>
+// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[T]], %[[VAL]], %[[FLOAT1]] : !torch.tensor<*,f32>, !torch.tensor<*,f32>,
+// CHECK-SAME: !torch.float -> !torch.tensor<*,f32>
+// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.tensor<*,f32> -> !torch.tensor<*,f32>
 // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[DIM]] : (!torch.int) -> !torch.list<!torch.int>
 // CHECK: %[[KEEP_DIM:.*]] = torch.constant.bool true
 // CHECK: %[[SUM_DTYPE:.*]] = torch.constant.none