mirror of https://github.com/llvm/torch-mlir

commit f461a7ebce (pull/437/head, snapshot-20211125.105)
parent 67ce816fca

[TORCH][MLIR] Add E2E support for aten._softmax operation. (#431)

Signed-off-by: Prateek Gupta <prateek@nod-labs.com>
@@ -457,6 +457,23 @@ class SoftmaxIntModule(torch.nn.Module):
 def SoftmaxIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randn(3, 2, 4))
 
 
+class _SoftmaxModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float32, True),
+    ])
+    def forward(self, tensor):
+        return torch.ops.aten._softmax(tensor, 0, False)
+
+
+@register_test_case(module_factory=lambda: _SoftmaxModule())
+def _SoftmaxModule_basic(module, tu: TestUtils):
+    module.forward(torch.randn(3, 2, 4))
+
+
 class SoftmaxIntNegDimModule(torch.nn.Module):
     def __init__(self):
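
Reviewer note (not part of the patch): `aten::_softmax` is the internal op that backs `torch.nn.functional.softmax`, so the new test can be cross-checked in eager mode. A minimal sketch, assuming a stock PyTorch build:

import torch

t = torch.randn(3, 2, 4)
# _softmax(self, dim, half_to_float): with half_to_float=False the result
# keeps the input dtype and matches the public softmax API along `dim`.
out = torch.ops.aten._softmax(t, 0, False)
ref = torch.nn.functional.softmax(t, dim=0)
assert torch.allclose(out, ref)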
@@ -1392,6 +1392,22 @@ def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [
   let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
 }
 
+def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    Torch_BoolType:$half_to_float
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
+}
+
 def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
     AllowsTypeRefinement
   ]> {
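
Reviewer note: the `summary` string mirrors the schema registered with the JIT. As a sanity check, the schema can be queried from Python (a sketch; `_jit_get_schemas_for_operator` is a private PyTorch API and an assumption here):

import torch

# Each schema prints roughly as
# "aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor",
# i.e. the (Tensor, int, bool) -> (Tensor) signature in the summary above.
for schema in torch._C._jit_get_schemas_for_operator("aten::_softmax"):
    print(schema)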
@@ -126,6 +126,26 @@ public:
 };
 } // namespace
 
+// Calculates the softmax function on the given `input` tensor. Softmax(x) =
+// exp(x)/sum(exp(x)).
+template <typename OpTy>
+static Value getSoftmaxResult(OpTy op, Type resultType,
+                              PatternRewriter &rewriter) {
+  Location loc = op.getLoc();
+  Value dim = op.dim();
+  Value self = op.self();
+
+  // exp(x)
+  Value exp = rewriter.create<AtenExpOp>(loc, resultType, self);
+  // sum(exp(x))
+  Value sum =
+      createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
+  if (!sum)
+    return nullptr;
+  // exp(x) / sum(exp(x))
+  return rewriter.create<AtenDivTensorOp>(loc, resultType, exp, sum);
+}
+
 // Decompose softmax into: exp(x) / sum(exp(x))
 namespace {
 class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
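
Reviewer note: `getSoftmaxResult` builds the textbook decomposition. A minimal eager-mode model of the same rewrite (illustrative only; like the decomposition itself, it omits the max-subtraction trick real kernels use for numerical stability):

import torch

def softmax_decomposed(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Mirrors getSoftmaxResult: exp(x) / sum(exp(x)), with keepDim=true so
    # the quotient broadcasts back over `dim`.
    exp = x.exp()
    return exp / exp.sum(dim, keepdim=True)

x = torch.randn(3, 2, 4)
assert torch.allclose(softmax_decomposed(x, 0), torch.softmax(x, 0))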
@@ -133,9 +153,7 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
                                 PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
     Value self = op.self();
-    Value dim = op.dim();
     if (!op.dtype().getType().isa<Torch::NoneType>())
       return rewriter.notifyMatchFailure(
           op, "Unimplemented non-None dtype for softmax");
@@ -144,14 +162,40 @@ public:
     if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
       return rewriter.notifyMatchFailure(op, "Only support floating type");
 
-    // exp(x)
-    Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self);
-    // sum(exp(x))
-    Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
-    if (!sum)
+    Value result = getSoftmaxResult(op, tensorType, rewriter);
+    if (!result)
       return failure();
-    // exp(x) / sum(exp(x))
-    Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum);
+    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
+                                                        result);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
+                                PatternRewriter &rewriter) const override {
+    Value self = op.self();
+    BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
+    if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(op, "Only support floating type");
+    bool halfToFloat;
+    if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected a boolean value for half_to_float");
+
+    // Currently, setting `halfToFloat` is not supported as the E2E testing for
+    // the same is not present on CPU.
+    if (halfToFloat)
+      return rewriter.notifyMatchFailure(
+          op, "halfToFloat is currently not supported.");
+
+    Value result = getSoftmaxResult(op, tensorType, rewriter);
+    if (!result)
+      return op.emitError("failed to get softmax result");
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
                                                         result);
     return success();
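
Reviewer note: the `halfToFloat` bail-out matches what is exercisable in eager mode. A sketch of the supported path (the claim about missing CPU coverage comes from the comment in the pattern above, not verified here):

import torch

x = torch.randn(3, 2, 4)
# half_to_float=False: output dtype matches the input dtype.
y = torch.ops.aten._softmax(x, 0, False)
assert y.dtype == x.dtype
# half_to_float=True would upcast a half-precision input to float32, but
# there is no CPU e2e coverage for it, so the decomposition refuses to
# match that case instead of emitting untested IR.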
@@ -406,6 +450,8 @@ class DecomposeComplexOpsPass
 
     patterns.add<DecomposeAtenSoftmaxIntOp>(context);
     target.addIllegalOp<AtenSoftmaxIntOp>();
+    patterns.add<DecomposeAten_SoftmaxOp>(context);
+    target.addIllegalOp<Aten_SoftmaxOp>();
     patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
     target.addIllegalOp<AtenLogSoftmaxIntOp>();
     patterns.add<DecomposeAtenExpandOp>(context);
@@ -418,6 +418,8 @@ public:
       return visitAtenMatmulOp(matmul, operands);
     } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
       return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
+    } else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
+      return visitAten_SoftmaxOp(_softmaxOp, operands);
     } else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
       return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
     } else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
@@ -541,6 +543,10 @@ private:
   ChangeResult
   visitAtenAddCLikeOp(Operation *op,
                       ArrayRef<LatticeElement<ValueKnowledge> *> operands);
+
+  ChangeResult
+  visitAten_SoftmaxOp(Aten_SoftmaxOp op,
+                      ArrayRef<LatticeElement<ValueKnowledge> *> operands);
 };
 } // namespace
 
@@ -1332,6 +1338,16 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+static ValueKnowledge
+getSameSizeAsInput(Operation *op,
+                   ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  knowledge.hasSizes = input.hasSizes;
+  knowledge.sizes = input.sizes;
+  return knowledge;
+}
+
 // Common template for softmax like ops, eg., log_softmax.
 template <typename OpTy>
@@ -1339,14 +1355,23 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
     OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto input = operands[0]->getValue();
   auto dtype = op.dtype();
-  auto knowledge =
-      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
-  knowledge.hasSizes = input.hasSizes;
-  knowledge.sizes = input.sizes;
+  ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
   fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+ChangeResult TypeAnalyzer::visitAten_SoftmaxOp(
+    Aten_SoftmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
+  bool halfToFloat;
+  if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) {
+    knowledge.dtype =
+        halfToFloat ? Float32Type::get(op->getContext()) : input.dtype;
+  }
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 ChangeResult TypeAnalyzer::visitAtenBmmOp(
     AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
   auto knowledge =
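
Reviewer note: the new transfer function keeps the input's sizes and derives the result dtype from the statically known `half_to_float` flag. A minimal Python model of that rule (the helper name is illustrative, not part of the patch):

import torch

def _softmax_result_dtype(input_dtype: torch.dtype,
                          half_to_float: bool) -> torch.dtype:
    # Mirrors visitAten_SoftmaxOp: float32 when half_to_float is true,
    # otherwise the input dtype propagates unchanged.
    return torch.float32 if half_to_float else input_dtype

assert _softmax_result_dtype(torch.float16, True) == torch.float32
assert _softmax_result_dtype(torch.float64, False) == torch.float64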
@@ -516,6 +516,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)")
     emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
     emit("aten::sqrt : (Tensor) -> (Tensor)")
+    emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
 
     # Misc tensor ops.
     emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")