mirror of https://github.com/llvm/torch-mlir
Add decomposition of _log_softmax op.
Decompose _log_softmax into log(softmax(x)).pull/562/head
parent
318946a650
commit
102c497c4c
|
@ -625,6 +625,24 @@ def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class _LogSoftmaxModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten._log_softmax(tensor, dim=0, half_to_float=False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: _LogSoftmaxModule())
|
||||
def _LogSoftmaxModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
|
||||
# ==============================================================================
|
||||
class BroadcastToModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -1624,6 +1624,22 @@ def Torch_AtenLogSoftmaxIntOp : Torch_Op<"aten.log_softmax.int", [
|
|||
let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($dtype)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_Aten_LogSoftmaxOp : Torch_Op<"aten._log_softmax", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_log_softmax : (Tensor, int, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
Torch_BoolType:$half_to_float
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` qualified(type($self)) `,` qualified(type($dim)) `,` qualified(type($half_to_float)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -468,6 +468,21 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten._log_softmax op into: log(_softmax(x))
|
||||
namespace {
|
||||
class DecomposeAten_LogSoftmaxOp : public OpRewritePattern<Aten_LogSoftmaxOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_LogSoftmaxOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value softmax = rewriter.create<Aten_SoftmaxOp>(
|
||||
op.getLoc(), op.getType(), op.self(), op.dim(), op.half_to_float());
|
||||
rewriter.replaceOpWithNewOp<AtenLogOp>(op, op.getType(), softmax);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.matmul into: aten.mm and aten.bmm according to ranks.
|
||||
namespace {
|
||||
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
|
||||
|
@ -981,6 +996,8 @@ class DecomposeComplexOpsPass
|
|||
target.addIllegalOp<AtenSoftmaxIntOp>();
|
||||
patterns.add<DecomposeAten_SoftmaxOp>(context);
|
||||
target.addIllegalOp<Aten_SoftmaxOp>();
|
||||
patterns.add<DecomposeAten_LogSoftmaxOp>(context);
|
||||
target.addIllegalOp<Aten_LogSoftmaxOp>();
|
||||
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
|
||||
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
||||
patterns.add<DecomposeAtenEmptyLikeOp>(context);
|
||||
|
|
|
@ -480,7 +480,9 @@ public:
|
|||
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
|
||||
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
|
||||
} else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
|
||||
return visitAten_SoftmaxOp(_softmaxOp, operands);
|
||||
return visitAten_SoftmaxLikeOp(_softmaxOp, operands);
|
||||
} else if (auto _logSoftmaxOp = dyn_cast<Aten_LogSoftmaxOp>(op)) {
|
||||
return visitAten_SoftmaxLikeOp(_logSoftmaxOp, operands);
|
||||
} else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
|
||||
return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
|
||||
} else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
|
||||
|
@ -646,8 +648,9 @@ private:
|
|||
visitAtenAddCLikeOp(Operation *op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
|
||||
template <typename OpTy>
|
||||
ChangeResult
|
||||
visitAten_SoftmaxOp(Aten_SoftmaxOp op,
|
||||
visitAten_SoftmaxLikeOp(OpTy op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
|
||||
ChangeResult visitAtenNllLossForwardOp(
|
||||
|
@ -1800,8 +1803,10 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
|
|||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAten_SoftmaxOp(
|
||||
Aten_SoftmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
// Common template for softmax like ops, eg., log_softmax.(underscore variant)
|
||||
template <typename OpTy>
|
||||
ChangeResult TypeAnalyzer::visitAten_SoftmaxLikeOp(
|
||||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
|
||||
bool halfToFloat;
|
||||
|
|
|
@ -535,6 +535,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit(
|
||||
"aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::_log_softmax : (Tensor, int, bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
|
||||
|
|
|
@ -342,3 +342,28 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to
|
|||
%1 = torch.aten._unsafe_view %arg0, %0 : !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int> -> !torch.vtensor<[512,32],f32>
|
||||
return %1 : !torch.vtensor<[512,32],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @_log.softmax(
|
||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[INP]], %[[INT0]], %[[TRUE]] : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?,?],f32>, !torch.vtensor<[1,?,?],si64>
|
||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[INP]], %[[VAL]], %[[FLOAT1]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[EXP:.*]] = torch.aten.exp %[[SUB]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[PRIM:.*]] = torch.prim.ListConstruct %[[INT0]] : (!torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: %[[TRU:.*]] = torch.constant.bool true
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[SUM_DIM:.*]] = torch.aten.sum.dim_IntList %[[EXP]], %[[PRIM]], %[[TRU]], %[[NONE]] : !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[1,?,?],f32>
|
||||
// CHECK: %[[SOFTMAX:.*]] = torch.aten.div.Tensor %[[EXP]], %[[SUM_DIM]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[1,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[SOFTMAX]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[LOG:.*]] = torch.aten.log %[[CAST]] : !torch.vtensor<[?,?,?],f32> -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[LOG]] : !torch.vtensor<[?,?,?],f32>
|
||||
func @_log.softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> {
|
||||
%int0 = torch.constant.int 0
|
||||
%false = torch.constant.bool false
|
||||
%0 = torch.aten._log_softmax %arg0, %int0, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?,?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue