mirror of https://github.com/llvm/torch-mlir
Add lowering of `aten.log_softmax` op.
The `aten.log_softmax` is decomposed into `aten.softmax` and `aten.log` op.pull/396/head
parent
127c7d8e27
commit
ef897dbb19
|
@ -512,3 +512,20 @@ class TensorToInt(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: TensorToInt())
|
||||
def TensorToInt_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10,[]), tu.rand())
|
||||
|
||||
class LogSoftmaxIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.log_softmax = torch.nn.LogSoftmax(2)
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return self.log_softmax.forward(tensor)
|
||||
|
||||
@register_test_case(module_factory=lambda: LogSoftmaxIntModule())
|
||||
def LogSoftmaxIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4).double())
|
||||
|
|
|
@ -1088,6 +1088,22 @@ def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [
|
|||
let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($dtype) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenLogSoftmaxIntOp : Torch_Op<"aten.log_softmax.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
TorchOptionalIntType:$dtype
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $dim `,` $dtype attr-dict `:` type($self) `,` type($dim) `,` type($dtype) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -88,6 +88,34 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.log_softmax op into: log(softmax(x))
|
||||
namespace {
|
||||
class DecomposeAtenLogSoftmaxIntOp
|
||||
: public OpRewritePattern<AtenLogSoftmaxIntOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.self();
|
||||
Value dim = op.dim();
|
||||
if (!op.dtype().getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented non-None dtype for log_softmax");
|
||||
|
||||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||
|
||||
// softmax(x, dim)
|
||||
Value softmax = rewriter.create<AtenSoftmaxIntOp>(loc, tensorType, self,
|
||||
dim, op.dtype());
|
||||
rewriter.replaceOpWithNewOp<AtenLogOp>(op, op.getType(), softmax);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose torch.matmul into: torch.mm and torch.bmm according to ranks.
|
||||
namespace {
|
||||
class DecomposeAtenMatmulOp : public OpRewritePattern<AtenMatmulOp> {
|
||||
|
@ -125,6 +153,8 @@ class DecomposeComplexOpsPass
|
|||
|
||||
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
|
||||
target.addIllegalOp<AtenSoftmaxIntOp>();
|
||||
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
|
||||
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
||||
patterns.add<DecomposeAtenMatmulOp>(context);
|
||||
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
|
||||
int lhsRank = getTensorRank(op.self());
|
||||
|
|
|
@ -411,7 +411,9 @@ public:
|
|||
} else if (auto matmul = dyn_cast<AtenMatmulOp>(op)) {
|
||||
return visitAtenMatmulOp(matmul, operands);
|
||||
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
|
||||
return visitAtenSoftmaxIntOp(softmaxIntOp, operands);
|
||||
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
|
||||
} else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
|
||||
return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
|
||||
}
|
||||
|
||||
// Otherwise, this is an unknown operation. Just mark all results as
|
||||
|
@ -511,11 +513,13 @@ private:
|
|||
visitAtenBmmOp(AtenBmmOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult
|
||||
visitAtenSoftmaxIntOp(AtenSoftmaxIntOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult
|
||||
visitAtenMatmulOp(AtenMatmulOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
|
||||
template <typename OpTy>
|
||||
ChangeResult
|
||||
visitAtenSoftmaxLikeOp(OpTy op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -1259,8 +1263,11 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
|
|||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenSoftmaxIntOp(
|
||||
AtenSoftmaxIntOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
|
||||
// Common template for softmax like ops, eg., log_softmax.
|
||||
template <typename OpTy>
|
||||
ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
|
||||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto dtype = op.dtype();
|
||||
auto knowledge =
|
||||
|
|
|
@ -495,6 +495,9 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit(
|
||||
"aten::softmax.int : (Tensor, int, int?) -> (Tensor)"
|
||||
)
|
||||
emit(
|
||||
"aten::log_softmax.int : (Tensor, int, int?) -> (Tensor)"
|
||||
)
|
||||
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue