Add decomposition of _log_softmax op.

Decompose _log_softmax into log(softmax(x)).
pull/562/head
Prashant Kumar 2022-02-10 07:05:23 +00:00
parent 318946a650
commit 102c497c4c
6 changed files with 88 additions and 4 deletions

View File

@ -625,6 +625,24 @@ def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class _LogSoftmaxModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, tensor):
return torch.ops.aten._log_softmax(tensor, dim=0, half_to_float=False)
@register_test_case(module_factory=lambda: _LogSoftmaxModule())
def _LogSoftmaxModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4))
# ==============================================================================
class BroadcastToModule(torch.nn.Module): class BroadcastToModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()

View File

@ -1624,6 +1624,22 @@ def Torch_AtenLogSoftmaxIntOp : Torch_Op<"aten.log_softmax.int", [
let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($dtype)) `->` qualified(type($result))"; let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($dtype)) `->` qualified(type($result))";
} }
def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::_log_softmax : (Tensor, int, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
Torch_BoolType:$half_to_float
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($half_to_float)) `->` qualified(type($result))";
}
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics HasValueSemantics

View File

@ -468,6 +468,21 @@ public:
}; };
} // namespace } // namespace
// Decompose aten._log_softmax op into: log(_softmax(x))
namespace {
class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_LogSoftmaxOp op,
PatternRewriter &rewriter) const override {
Value softmax = rewriter.create<Aten_SoftmaxOp>(
op.getLoc(), op.getType(), op.self(), op.dim(), op.half_to_float());
rewriter.replaceOpWithNewOp<AtenLogOp>(op, op.getType(), softmax);
return success();
}
};
} // namespace
// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks. // Decompose aten.matmul into: aten.mm and aten.bmm according to ranks.
namespace { namespace {
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> { class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
@ -981,6 +996,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenSoftmaxIntOp>(); target.addIllegalOp<AtenSoftmaxIntOp>();
patterns.add<DecomposeAten_SoftmaxOp>(context); patterns.add<DecomposeAten_SoftmaxOp>(context);
target.addIllegalOp<Aten_SoftmaxOp>(); target.addIllegalOp<Aten_SoftmaxOp>();
patterns.add<DecomposeAten_LogSoftmaxOp>(context);
target.addIllegalOp<Aten_LogSoftmaxOp>();
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context); patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
target.addIllegalOp<AtenLogSoftmaxIntOp>(); target.addIllegalOp<AtenLogSoftmaxIntOp>();
patterns.add<DecomposeAtenEmptyLikeOp>(context); patterns.add<DecomposeAtenEmptyLikeOp>(context);

View File

@ -480,7 +480,9 @@ public:
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) { } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands); return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
} else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) { } else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
return visitAten_SoftmaxOp(_softmaxOp, operands); return visitAten_SoftmaxLikeOp(_softmaxOp, operands);
} else if (auto _logSoftmaxOp = dyn_cast<Aten_LogSoftmaxOp>(op)) {
return visitAten_SoftmaxLikeOp(_logSoftmaxOp, operands);
} else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) { } else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands); return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
} else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) { } else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
@ -646,8 +648,9 @@ private:
visitAtenAddCLikeOp(Operation *op, visitAtenAddCLikeOp(Operation *op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands); ArrayRef<LatticeElement<ValueKnowledge> *> operands);
template <typename OpTy>
ChangeResult ChangeResult
visitAten_SoftmaxOp(Aten_SoftmaxOp op, visitAten_SoftmaxLikeOp(OpTy op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands); ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitAtenNllLossForwardOp( ChangeResult visitAtenNllLossForwardOp(
@ -1800,8 +1803,10 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
return getLatticeElement(op.getResult()).join(knowledge); return getLatticeElement(op.getResult()).join(knowledge);
} }
ChangeResult TypeAnalyzer::visitAten_SoftmaxOp( // Common template for softmax like ops, eg., log_softmax.(underscore variant)
Aten_SoftmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { template <typename OpTy>
ChangeResult TypeAnalyzer::visitAten_SoftmaxLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue(); auto input = operands[0]->getValue();
ValueKnowledge knowledge = getSameSizeAsInput(op, operands); ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
bool halfToFloat; bool halfToFloat;

View File

@ -535,6 +535,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit( emit(
"aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)" "aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)"
) )
emit(
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
)
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)") emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")

View File

@ -342,3 +342,28 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to
%1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[512,32],f32> %1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[512,32],f32>
return %1 : !torch.vtensor<[512,32],f32> return %1 : !torch.vtensor<[512,32],f32>
} }
// -----
// CHECK-LABEL: func @_log.softmax(
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[INT0]], %[[TRUE]] : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?,?],f32>, !torch.vtensor<[1,?,?],si64>
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[INP]], %[[VAL]], %[[FLOAT1]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[PRIM:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<!torch.int>
// CHECK: %[[TRU:.*]] = torch.constant.bool true
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[SUM_DIM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[PRIM]], %[[TRU]], %[[NONE]] : !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[1,?,?],f32>
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM_DIM]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor<[?,?,?],f32>
// CHECK: %[[LOG:.*]] = torch.aten.log %[[CAST]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[LOG]] : !torch.vtensor<[?,?,?],f32>
func @_log.softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> {
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%0 = torch.aten._log_softmax %arg0, %int0, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}