[TORCH][MLIR] Add E2E support for aten._softmax operation. (#431)

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>
pull/437/head snapshot-20211125.105
Prateek Gupta 2021-11-25 11:19:02 +05:30 committed by GitHub
parent 67ce816fca
commit f461a7ebce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 118 additions and 13 deletions

View File

@ -457,6 +457,23 @@ class SoftmaxIntModule(torch.nn.Module):
def SoftmaxIntModule_basic(module, tu: TestUtils): def SoftmaxIntModule_basic(module, tu: TestUtils):
module.forward(torch.randn(3, 2, 4)) module.forward(torch.randn(3, 2, 4))
class _SoftmaxModule(torch.nn.Module):
    """E2E test module that invokes the `aten._softmax` op directly."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        # NOTE(review): presumably a rank-3 float32 tensor with dynamic
        # dimensions -- confirm against the annotate_args contract.
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, tensor):
        # `aten::_softmax` takes (Tensor, int, bool); the int is the
        # dimension and the bool is `half_to_float`.
        softmax_dim = 0
        half_to_float = False
        return torch.ops.aten._softmax(tensor, softmax_dim, half_to_float)
@register_test_case(module_factory=lambda: _SoftmaxModule())
def _SoftmaxModule_basic(module, tu: TestUtils):
    """Run `_SoftmaxModule.forward` once on a random tensor of shape (3, 2, 4)."""
    random_input = torch.randn(3, 2, 4)
    module.forward(random_input)
class SoftmaxIntNegDimModule(torch.nn.Module): class SoftmaxIntNegDimModule(torch.nn.Module):
def __init__(self): def __init__(self):

View File

@ -1392,6 +1392,22 @@ def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)"; let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
} }
// ODS record for the `torch.aten._softmax` op. Its summary states that it is a
// generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`.
// NOTE(review): presumably emitted by the op-generator script, so hand edits
// here may be overwritten on regeneration -- confirm before editing.
def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::_softmax : (Tensor, int, bool) -> (Tensor)`";
// Operands, in signature order: the input tensor, the dimension, and the
// `half_to_float` flag.
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
Torch_BoolType:$half_to_float
);
// Single tensor result.
let results = (outs
AnyTorchTensorType:$result
);
// Custom assembly: the three operands, the attribute dictionary, then the
// operand types followed by `->` and the result type.
let assemblyFormat = "$self `,` $dim `,` $half_to_float attr-dict `:` type($self) `,` type($dim) `,` type($half_to_float) `->` type($result)";
}
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [ def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
AllowsTypeRefinement AllowsTypeRefinement
]> { ]> {

View File

@ -126,6 +126,26 @@ public:
}; };
} // namespace } // namespace
// Builds the softmax of `op.self()` along `op.dim()` using
//   Softmax(x) = exp(x) / sum(exp(x))
// and returns the resulting value, or a null Value when the summation helper
// does not produce a value. `resultType` is used as the result type of both
// the exp op and the div op.
// NOTE(review): max(x) is not subtracted before exponentiation, so large
// inputs can overflow in exp(x) -- confirm this is acceptable for callers.
template <typename OpTy>
static Value getSoftmaxResult(OpTy op, Type resultType,
                              PatternRewriter &rewriter) {
  Location loc = op.getLoc();

  // Numerator: exp(x).
  Value numerator = rewriter.create<AtenExpOp>(loc, resultType, op.self());

  // Denominator: sum(exp(x)) along `dim`, requested with keepDim=true.
  Value denominator = createSumAlongDimension(rewriter, loc, op, numerator,
                                              op.dim(), /*keepDim=*/true);
  if (!denominator)
    return nullptr;

  // exp(x) / sum(exp(x))
  return rewriter.create<AtenDivTensorOp>(loc, resultType, numerator,
                                          denominator);
}
// Decompose softmax into: exp(x) / sum(exp(x)) // Decompose softmax into: exp(x) / sum(exp(x))
namespace { namespace {
class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> { class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
@ -133,9 +153,7 @@ public:
using OpRewritePattern::OpRewritePattern; using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op, LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.self(); Value self = op.self();
Value dim = op.dim();
if (!op.dtype().getType().isa<Torch::NoneType>()) if (!op.dtype().getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Unimplemented non-None dtype for softmax"); op, "Unimplemented non-None dtype for softmax");
@ -144,14 +162,40 @@ public:
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>()) if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type"); return rewriter.notifyMatchFailure(op, "Only support floating type");
// exp(x) Value result = getSoftmaxResult(op, tensorType, rewriter);
Value exp = rewriter.create<AtenExpOp>(loc, tensorType, self); if (!result)
// sum(exp(x))
Value sum = createSumAlongDimension(rewriter, loc, op, exp, dim, /*keepDim=*/true);
if (!sum)
return failure(); return failure();
// exp(x) / sum(exp(x)) rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
Value result = rewriter.create<AtenDivTensorOp>(loc, tensorType, exp, sum); result);
return success();
}
};
} // namespace
namespace {
class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
PatternRewriter &rewriter) const override {
Value self = op.self();
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "Only support floating type");
bool halfToFloat;
if (!matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat)))
return rewriter.notifyMatchFailure(
op, "Expected a boolean value for half_to_float");
// Currently, setting `halfToFloat` is not supported as the E2E testing for
// the same is not present on CPU.
if (halfToFloat)
return rewriter.notifyMatchFailure(
op, "halfToFloat is currently not supported.");
Value result = getSoftmaxResult(op, tensorType, rewriter);
if (!result)
return op.emitError("failed to get softmax result");
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(), rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
result); result);
return success(); return success();
@ -406,6 +450,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenSoftmaxIntOp>(context); patterns.add<DecomposeAtenSoftmaxIntOp>(context);
target.addIllegalOp<AtenSoftmaxIntOp>(); target.addIllegalOp<AtenSoftmaxIntOp>();
patterns.add<DecomposeAten_SoftmaxOp>(context);
target.addIllegalOp<Aten_SoftmaxOp>();
patterns.add<DecomposeAtenLogSoftmaxIntOp>(context); patterns.add<DecomposeAtenLogSoftmaxIntOp>(context);
target.addIllegalOp<AtenLogSoftmaxIntOp>(); target.addIllegalOp<AtenLogSoftmaxIntOp>();
patterns.add<DecomposeAtenExpandOp>(context); patterns.add<DecomposeAtenExpandOp>(context);

View File

@ -418,6 +418,8 @@ public:
return visitAtenMatmulOp(matmul, operands); return visitAtenMatmulOp(matmul, operands);
} else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) { } else if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
return visitAtenSoftmaxLikeOp(softmaxIntOp, operands); return visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
} else if (auto _softmaxOp = dyn_cast<Aten_SoftmaxOp>(op)) {
return visitAten_SoftmaxOp(_softmaxOp, operands);
} else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) { } else if (auto logSoftmaxIntOp = dyn_cast<AtenLogSoftmaxIntOp>(op)) {
return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands); return visitAtenSoftmaxLikeOp(logSoftmaxIntOp, operands);
} else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) { } else if (auto numToTensorOp = dyn_cast<PrimNumToTensorScalarOp>(op)) {
@ -541,6 +543,10 @@ private:
ChangeResult ChangeResult
visitAtenAddCLikeOp(Operation *op, visitAtenAddCLikeOp(Operation *op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands); ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAten_SoftmaxOp(Aten_SoftmaxOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
}; };
} // namespace } // namespace
@ -1332,6 +1338,16 @@ ChangeResult TypeAnalyzer::visitAtenEmbeddingOp(
return getLatticeElement(op.getResult()).join(knowledge); return getLatticeElement(op.getResult()).join(knowledge);
} }
// Returns a ValueKnowledge that starts from the not-None pessimistic state and
// takes its size information (`hasSizes`, `sizes`) from the first operand.
// The dtype is not touched here.
static ValueKnowledge
getSameSizeAsInput(Operation *op,
                   ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  ValueKnowledge result =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  // Only the size fields are propagated from operand 0.
  const auto &inputKnowledge = operands[0]->getValue();
  result.hasSizes = inputKnowledge.hasSizes;
  result.sizes = inputKnowledge.sizes;
  return result;
}
// Common template for softmax like ops, eg., log_softmax. // Common template for softmax like ops, eg., log_softmax.
template <typename OpTy> template <typename OpTy>
@ -1339,14 +1355,23 @@ ChangeResult TypeAnalyzer::visitAtenSoftmaxLikeOp(
OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { OpTy op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue(); auto input = operands[0]->getValue();
auto dtype = op.dtype(); auto dtype = op.dtype();
auto knowledge = ValueKnowledge knowledge = getSameSizeAsInput(op, operands);
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.hasSizes = input.hasSizes;
knowledge.sizes = input.sizes;
fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype); fillInDTypeGivenDTypeIntAndInputDType(knowledge, dtype, input.dtype);
return getLatticeElement(op.getResult()).join(knowledge); return getLatticeElement(op.getResult()).join(knowledge);
} }
// Transfer function for `aten._softmax`: the result takes its sizes from the
// input (operand 0). When `half_to_float` matches a constant bool, the result
// dtype is f32 if the flag is true and the input's dtype otherwise. When it is
// not a constant, the dtype is left as getSameSizeAsInput returned it.
ChangeResult TypeAnalyzer::visitAten_SoftmaxOp(
    Aten_SoftmaxOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  ValueKnowledge knowledge = getSameSizeAsInput(op, operands);

  bool halfToFloat;
  if (matchPattern(op.half_to_float(), m_TorchConstantBool(&halfToFloat))) {
    if (halfToFloat)
      knowledge.dtype = Float32Type::get(op->getContext());
    else
      knowledge.dtype = operands[0]->getValue().dtype;
  }

  return getLatticeElement(op.getResult()).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenBmmOp( ChangeResult TypeAnalyzer::visitAtenBmmOp(
AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) { AtenBmmOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto knowledge = auto knowledge =

View File

@ -516,6 +516,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)") emit("aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)")
emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::sqrt : (Tensor) -> (Tensor)") emit("aten::sqrt : (Tensor) -> (Tensor)")
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
# Misc tensor ops. # Misc tensor ops.
emit("aten::unsqueeze : (Tensor, int) -> (Tensor)") emit("aten::unsqueeze : (Tensor, int) -> (Tensor)")