From 18e8806b14629d9fc56ab53ce3da474ff6b2c693 Mon Sep 17 00:00:00 2001
From: Prateek Gupta
Date: Mon, 1 Nov 2021 11:46:46 +0000
Subject: [PATCH] [TORCH][MLIR] Add E2E support for aten::to.dtype.

This commit adds end-to-end support for AtenToDtypeOp, lowering it from
aten to linalg.

Signed-off-by: Prateek Gupta
---
 e2e_testing/torchscript/elementwise.py        | 21 ++++++++++++-
 .../TorchToLinalg/TorchToLinalg.cpp           | 30 +++++++++++++++----
 .../Transforms/MaximizeValueSemantics.cpp     |  4 +--
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 30 +++++++++++++++----
 test/Dialect/Torch/refine-types.mlir          | 18 +++++++++++
 5 files changed, 89 insertions(+), 14 deletions(-)

diff --git a/e2e_testing/torchscript/elementwise.py b/e2e_testing/torchscript/elementwise.py
index 7143d25e8..6fdae0271 100644
--- a/e2e_testing/torchscript/elementwise.py
+++ b/e2e_testing/torchscript/elementwise.py
@@ -311,7 +311,7 @@ class ElementwiseClampModule(torch.nn.Module):
 def ElementwiseClampModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5, low=-10, high=10))
 
-
+# ==============================================================================
 class RsubModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -344,6 +344,7 @@ class RsubModule_noalpha(torch.nn.Module):
 def RsubModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+# ==============================================================================
 class ElementwiseLogModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -410,3 +411,21 @@ class ElementwisePowModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: ElementwisePowModule())
 def ElementwisePowModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
+
+# ==============================================================================
+
+class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True)
+    ])
+    def forward(self, x):
+        return x.to(torch.int64)
+
+@register_test_case(module_factory=lambda: ElementwiseToDtypeF32ToI64Module())
+def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 5))
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 073e78308..f7d5a1217 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -1470,6 +1470,26 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value mult = b.create(loc, self, alpha);
     return b.create(loc, other, mult);
   }
+  if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
+    Value input = payloadArgs[0];
+    Type inType = input.getType();
+    Type outType = atenToDtype.getType().cast().getDtype();
+    Value result;
+    if (!inType.isF32()) {
+      atenToDtype.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    if (inType == outType)
+      result = input;
+    else if (outType.isInteger(64))
+      result = b.create(loc, b.getI64Type(), input);
+    else if (outType.isInteger(1))
+      result = b.create(loc, b.getI1Type(), input);
+    else
+      atenToDtype.emitError("unimplemented: unsupported target dtype");
+    return result;
+  }
+
   op->emitError("unimplemented lowering in "
                 "createLinalgPayloadCalculationForElementwiseOp");
   return nullptr;
@@ -1679,8 +1699,8 @@ struct ConvertElementwiseOp : ConversionPattern {
     if (!isa<
+             AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
+             AtenLogOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -2816,8 +2836,9 @@ public:
     target.addIllegalOp<
+                        AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
+                        AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp,
+                        AtenPowTensorScalarOp>();
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
@@ -2857,7 +2878,6 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
-
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
index 1066edc55..bae2a2881 100644
--- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
+++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp
@@ -91,8 +91,8 @@ public:
       copyToValueTensorOps.push_back(copyToValueTensor);
     } else if (isa<
+               AtenBroadcastToOp, AtenToDtypeOp, AtenContiguousOp,
+               AtenPermuteOp, AtenViewOp, AtenExpandOp>(op)) {
       // AtenContiguousOp might return a view, so this is conservatively
       // correct. We could potentially be more precise and identify the cases
       // that it does not return a view and treat those as having value
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 8aa5ca917..4570fa116 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -226,12 +226,11 @@ public:
     if (isa<
+          AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp, DerefineOp,
+          AtenToPrimDeviceOp, AtenCpuOp, AtenContiguousOp, AtenFill_ScalarOp,
+          AtenDetachOp, AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op,
+          AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp, AtenClampOp,
+          AtenRsubScalarOp, AtenLogOp, AtenSqrtOp, AtenFloorOp>(op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
@@ -346,6 +345,8 @@ public:
     } else if (auto emptyMemoryFormat = dyn_cast<AtenEmptyMemoryFormatOp>(op)) {
       return visitConstantTensorAllocOp(
           emptyMemoryFormat);
+    } else if (auto toDtype = dyn_cast<AtenToDtypeOp>(op)) {
+      return visitAtenToDtypeOp(toDtype, operands);
     } else if (auto toOther = dyn_cast<AtenToOtherOp>(op)) {
       return visitTypeConversionOp(toOther, operands);
     } else if (auto typeAs = dyn_cast<AtenTypeAsOp>(op)) {
@@ -480,6 +481,9 @@ private:
   ChangeResult visitScalarToTensorConversionOp(OpTy op);
   ChangeResult visitAtenTensorOp(AtenTensorOp op);
   template <typename OpTy> ChangeResult visitConstantTensorAllocOp(OpTy op);
+  ChangeResult
+  visitAtenToDtypeOp(AtenToDtypeOp op,
+                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
   template <typename OpTy>
   ChangeResult visitTypeConversionOp(OpTy op,
@@ -1078,6 +1082,20 @@ ChangeResult TypeAnalyzer::visitConstantTensorAllocOp(OpTy op) {
   return getLatticeElement(op.getResult()).join(knowledge);
 }
 
+// Convert input tensor type to the given `dtype`.
+ChangeResult TypeAnalyzer::visitAtenToDtypeOp(
+    AtenToDtypeOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
+  auto input = operands[0]->getValue();
+  auto knowledge =
+      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
+  knowledge.hasSizes = input.hasSizes;
+  if (input.hasSizes)
+    knowledge.sizes = input.sizes;
+  Value dtype = op.dtype();
+  fillInDTypeGivenDTypeAndDataType(knowledge, dtype, input.dtype);
+  return getLatticeElement(op.getResult()).join(knowledge);
+}
+
 // Convert input tensor type to the same as the other tensor.
 template <typename OpTy>
 ChangeResult TypeAnalyzer::visitTypeConversionOp(
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index 2e8003217..3d2fcf384 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -1004,3 +1004,21 @@ func @aten_matmul_broadcast_vector(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1
   %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?],f32> -> !torch.tensor
   return %0 : !torch.tensor
 }
+
+// -----
+// CHECK-LABEL: func @torch.aten.to.dtype
+// CHECK-SAME:    (%[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor
+// CHECK:         %[[TODTYPE:.*]] = torch.aten.to.dtype
+// CHECK-SAME:      %[[ARG]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} :
+// CHECK-SAME:      !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none
+// CHECK-SAME:      -> !torch.tensor<[?,?],si64>
+// CHECK-NEXT:    %[[RES:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.tensor<[?,?],si64> to !torch.tensor
+// CHECK-NEXT:    return %[[RES]] : !torch.tensor
+
+func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %int4 = torch.constant.int 4
+  %0 = torch.aten.to.dtype %arg0, %int4, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
+  return %0 : !torch.tensor
+}
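
For reference, the numeric contract the new ElementwiseToDtypeF32ToI64Module test pins down can be checked in eager PyTorch alone: Tensor.to(torch.int64) truncates toward zero, which is also what an fptosi-style float-to-signed-integer cast in the generated linalg payload produces. The sketch below is illustrative and not part of the patch; it assumes nothing beyond stock PyTorch.

# Illustrative sanity check (not part of the patch): eager semantics that the
# f32 -> i64 path of the new elementwise lowering has to match.
import torch

# Casting float32 to int64 truncates toward zero: 1.9 -> 1, -1.9 -> -1.
x = torch.tensor([[-1.9, -0.5, 0.0], [0.5, 1.9, 2.0]], dtype=torch.float32)
y = x.to(torch.int64)

print(y)                 # tensor([[-1,  0,  0],
                         #         [ 0,  1,  2]])
print(y.dtype, y.shape)  # torch.int64 torch.Size([2, 3])

assert y.dtype == torch.int64
assert torch.equal(y, torch.tensor([[-1, 0, 0], [0, 1, 2]], dtype=torch.int64))

The torchscript e2e harness compares the compiled module's output against exactly this eager result, so a rounding-mode mismatch in the f32 to i64 path would surface as a miscompare in ElementwiseToDtypeF32ToI64Module_basic.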