From bc9abbc1c97ed9c66768bf9d7675323732f27826 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 14 Dec 2021 00:31:10 +0530 Subject: [PATCH] [TORCH][MLIR] Add E2E support for `aten.empty_like` op This commit adds decomposition of `aten.empty_like` into `aten.empty` op. Signed-off-by: Gaurav Shukla --- e2e_testing/torchscript/basic.py | 42 +++++++++++++++++-- .../Torch/Transforms/DecomposeComplexOps.cpp | 29 +++++++++++++ lib/Dialect/Torch/Transforms/RefineTypes.cpp | 2 +- 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index d4b8f981d..320dda05c 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -663,7 +663,7 @@ class EmptyFloatModule(torch.nn.Module): None, ]) def forward(self): - return torch.abs(torch.empty((3, 4), dtype=torch.float32)) > -1.0 + return torch.pow(torch.empty((3, 4), dtype=torch.float32), 0) @register_test_case(module_factory=lambda: EmptyFloatModule()) def EmptyModule_float(module, tu: TestUtils): @@ -679,8 +679,8 @@ class EmptyFalsePinMemoryModule(torch.nn.Module): None, ]) def forward(self): - return torch.abs(torch.empty((3, 4), dtype=torch.float32, - pin_memory=False)) > -1.0 + return torch.pow(torch.empty((3, 4), dtype=torch.float32, + pin_memory=False), 0) @register_test_case(module_factory=lambda: EmptyFalsePinMemoryModule()) def EmptyModule_falsePinMemory(module, tu: TestUtils): @@ -688,6 +688,42 @@ def EmptyModule_falsePinMemory(module, tu: TestUtils): # ============================================================================== +class EmptyLikeIntModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, a): + return 0 * torch.empty_like(a, dtype=torch.int64) + +@register_test_case(module_factory=lambda: EmptyLikeIntModule()) +def EmptyLikeModule_int(module, tu: TestUtils): + 
module.forward(torch.randint(10, (3, 5))) + +# ============================================================================== + +class EmptyLikeFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.pow(torch.empty_like(a, dtype=torch.float32), 0) + +@register_test_case(module_factory=lambda: EmptyLikeFloatModule()) +def EmptyLikeModule_float(module, tu: TestUtils): + module.forward(tu.rand(4, 5)) + +# ============================================================================== + class ContiguousModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 7651fd955..a14add937 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -506,6 +506,33 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { }; } // namespace +// Decompose `aten.empty_like` op into `aten.size` and `aten.empty` ops. +namespace { +class DecomposeAtenEmptyLikeOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenEmptyLikeOp op, + PatternRewriter &rewriter) const override { + auto sizeListType = + Torch::ListType::get(Torch::IntType::get(op.getContext())); + Value sizeList = + rewriter.create(op.getLoc(), sizeListType, op.self()); + + // TODO: Handle the case when `dtype` is NoneType. 
+ Type dtype = op.dtype().getType(); + if (dtype.isa() || dtype.isa() || + dtype.isa()) + return rewriter.notifyMatchFailure( + op, "unimplemented: None dtype is not supported"); + + rewriter.replaceOpWithNewOp( + op, op.getType(), sizeList, op.dtype(), op.layout(), op.device(), + op.pin_memory(), op.memory_format()); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -521,6 +548,8 @@ class DecomposeComplexOpsPass target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); patterns.add(context); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index d895dd027..df8dd9597 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -237,7 +237,7 @@ public: AtenTanhOp, AtenBatchNormOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp, - AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp, + AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp, AtenEmptyLikeOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,