mirror of https://github.com/llvm/torch-mlir
[TORCH][MLIR] Add E2E support for aten._softmax operation. (#431)
Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>pull/437/head snapshot-20211125.105
parent
67ce816fca
commit
f461a7ebce
|
@ -457,6 +457,23 @@ class SoftmaxIntModule(torch.nn.Module):
|
|||
def SoftmaxIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
|
||||
class _SoftmaxModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten._softmax(tensor, 0, False)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: _SoftmaxModule())
|
||||
def _SoftmaxModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4))
|
||||
|
||||
|
||||
class SoftmaxIntNegDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
@ -1392,6 +1392,22 @@ def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [
|
|||
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
Torch_BoolType:$half_to_float
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
|
|
@ -126,6 +126,26 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
||||
// exp(x)/sum(exp(x)).
|
||||
template <typename OpTy>
|
||||
static Value getSoftmaxResult(OpTy op, Type resultType,
|
||||
PatternRewriter &rewriter) {
|
||||
Location loc = op.getLoc();
|
||||
Value dim = op.dim();
|
||||
Value self = op.self();
|
||||
|
||||
// exp(x)
|
||||
Value exp = rewriter.create<AtenExpOp>(loc, resultType, self);
|
||||
// sum(exp(x))
|
||||
Value sum =
|
||||
createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
|
||||
if (!sum)
|
||||
return nullptr;
|
||||
// exp(x) / sum(exp(x))
|
||||
return rewriter.create<AtenDivTensorOp>(loc, resultType, exp, sum);
|
||||
}
|
||||
|
||||
// Decompose softmax into: exp(x) / sum(exp(x))
|
||||
namespace {
|
||||
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
|
||||
|
@ -133,9 +153,7 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value self = op.self();
|
||||
Value dim = op.dim();
|
||||
if (!op.dtype().getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented non-None dtype for softmax");
|
||||
|
@ -144,14 +162,40 @@ public:
|
|||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||
|
||||
// exp(x)
|
||||
Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
|
||||
// sum(exp(x))
|
||||
Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
|
||||
if (!sum)
|
||||
Value result = getSoftmaxResult(op, tensorType, rewriter);
|
||||
if (!result)
|
||||
return failure();
|
||||
// exp(x) / sum(exp(x))
|
||||
Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
|
||||
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
|
||||
result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Value self = op.self();
|
||||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||
bool halfToFloat;
|
||||
if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Expected a boolean value for half_to_float");
|
||||
|
||||
// Currently, setting `halfToFloat` is not supported as the E2E testing for
|
||||
// the same is not present on CPU.
|
||||
if (halfToFloat)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "halfToFloat is currently not supported.");
|
||||
|
||||
Value result = getSoftmaxResult(op, tensorType, rewriter);
|
||||
if (!result)
|
||||
return op.emitError("failed to get softmax result");
|
||||
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
|
||||
result);
|
||||
return success();
|
||||
|
@ -406,6 +450,8 @@ class DecomposeComplexOpsPass
|
|||
|
||||
patterns.add<DecomposeAtenSoftmaxIntOp>(context);
|
||||
target.addIllegalOp<AtenSoftmaxIntOp>();
|
||||
patterns.add<DecomposeAten_SoftmaxOp>(context);
|
||||
target.addIllegalOp<Aten_SoftmaxOp>();
|
||||
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
|
||||
target.addIllegalOp<AtenLogSoftmaxIntOp>();
|
||||
patterns.add<DecomposeAtenExpandOp>(context);
|
||||
|
|
|
@ -418,6 +418,8 @@ public:
|
|||
return visitAtenMatmulOp(matmul, operands);
|
||||
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
|
||||
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
|
||||
} else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
|
||||
return visitAten_SoftmaxOp(_softmaxOp, operands);
|
||||
} else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
|
||||
return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
|
||||
} else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
|
||||
|
@ -541,6 +543,10 @@ private:
|
|||
ChangeResult
|
||||
visitAtenAddCLikeOp(Operation *op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
|
||||
ChangeResult
|
||||
visitAten_SoftmaxOp(Aten_SoftmaxOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -1332,6 +1338,16 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
|
|||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
static ValueKnowledge
|
||||
getSameSizeAsInput(Operation *op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = input.hasSizes;
|
||||
knowledge.sizes = input.sizes;
|
||||
return knowledge;
|
||||
}
|
||||
|
||||
// Common template for softmax like ops, eg., log_softmax.
|
||||
template <typename OpTy>
|
||||
|
@ -1339,14 +1355,23 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
|
|||
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto dtype = op.dtype();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.hasSizes = input.hasSizes;
|
||||
knowledge.sizes = input.sizes;
|
||||
ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
|
||||
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAten_SoftmaxOp(
|
||||
Aten_SoftmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
|
||||
bool halfToFloat;
|
||||
if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) {
|
||||
knowledge.dtype =
|
||||
halfToFloat ? Float32Type::get(op->getContext()) : input.dtype;
|
||||
}
|
||||
return getLatticeElement(op.getResult()).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitAtenBmmOp(
|
||||
AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto knowledge =
|
||||
|
|
|
@ -516,6 +516,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)")
|
||||
emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::sqrt : (Tensor) -> (Tensor)")
|
||||
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
|
||||
|
||||
# Misc tensor ops.
|
||||
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue