From ab81f871e4db3bd080697c2b9be03d36a06f6c42 Mon Sep 17 00:00:00 2001
From: Prashant Kumar
Date: Tue, 14 Dec 2021 17:45:07 +0530
Subject: [PATCH] Add aten.tensor.int and aten.tensor.float op lowerings.

Add the required lowerings and correct the existing test cases.
These ops produce zero-d tensors; RefineTypes previously reported them,
incorrectly, as 1-d tensors of size 1.
---
 e2e_testing/torchscript/basic.py              | 35 ++++++++++++
 .../TorchToLinalg/TorchToLinalg.cpp           | 56 +++++++++++++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  | 11 ++--
 test/Dialect/Torch/refine-types.mlir          |  8 +--
 4 files changed, 100 insertions(+), 10 deletions(-)

diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py
index bec21657e..d4b8f981d 100644
--- a/e2e_testing/torchscript/basic.py
+++ b/e2e_testing/torchscript/basic.py
@@ -835,6 +835,41 @@ def AddCDivModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
+class tensorIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+
+    def forward(self):
+        a = 1
+        return torch.tensor(a)
+
+@register_test_case(module_factory=lambda: tensorIntModule())
+def TensorIntModule_basic(module, tu: TestUtils):
+    module.forward()
+
+class tensorFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+
+    def forward(self):
+        a = 1.0
+        return torch.tensor(a)
+
+@register_test_case(module_factory=lambda: tensorFloatModule())
+def TensorFloatModule_basic(module, tu: TestUtils):
+    module.forward()
+
+# ==============================================================================
 
 class DropoutModule(torch.nn.Module):
 
     def __init__(self):
diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
index 7b832dcac..7fca15acf 100644
--- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
+++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp
@@ -3451,6 +3451,60 @@ public:
 };
 } // namespace
 
+namespace {
+struct ConvertAtenScalarToTensorLike : ConversionPattern {
+  ConvertAtenScalarToTensorLike(TypeConverter &typeConverter,
+                                MLIRContext *context)
+      : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1,
+                          context) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isa<AtenTensorIntOp, AtenTensorFloatOp>(op))
+      return rewriter.notifyMatchFailure(
+          op, "not a supported Scalar to Tensor like op");
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op->getLoc();
+    Value elemVal, dtype, device, requires_grad;
+    if (AtenTensorIntOp tensorIntOp = dyn_cast<AtenTensorIntOp>(op)) {
+      AtenTensorIntOp::Adaptor adaptor(operands);
+      elemVal = adaptor.t();
+      dtype = tensorIntOp.dtype();
+      device = tensorIntOp.device();
+      requires_grad = tensorIntOp.requires_grad();
+    }
+    if (AtenTensorFloatOp tensorFloatOp = dyn_cast<AtenTensorFloatOp>(op)) {
+      AtenTensorFloatOp::Adaptor adaptor(operands);
+      elemVal = adaptor.t();
+      dtype = tensorFloatOp.dtype();
+      device = tensorFloatOp.device();
+      requires_grad = tensorFloatOp.requires_grad();
+    }
+    // TODO: Dtype conversion.
+    if (!dtype.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(op, "Unimplemented non-None dtype");
+
+    // TODO: Device information.
+    if (!device.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "Unimplemented non-None device information");
+
+    RankedTensorType resultType = getTypeConverter()
+                                      ->convertType(op->getResult(0).getType())
+                                      .cast<RankedTensorType>();
+    Type outElementType = resultType.getElementType();
+    Value elemValProm =
+        convertScalarToDtype(rewriter, loc, elemVal, outElementType);
+    Value zeroDTensor =
+        createInitTensor(rewriter, loc, {}, outElementType, elemValProm);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, zeroDTensor);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Converts constant tensor allocation like ops.
 template
@@ -3751,6 +3805,8 @@ public:
     patterns.add(typeConverter, context);
     target.addIllegalOp();
     patterns.add(typeConverter, context);
+    patterns.add<ConvertAtenScalarToTensorLike>(typeConverter, context);
+    target.addIllegalOp<AtenTensorIntOp, AtenTensorFloatOp>();
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index a0142683f..80bf0b395 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -238,11 +238,11 @@ public:
         AtenGeluBackwardOp, AtenBitwiseNotOp, AtenExpOp, AtenSinOp, AtenCosOp,
         AtenSigmoidOp, DerefineOp, AtenToPrimDeviceOp, AtenCpuOp,
         AtenContiguousOp, AtenFill_ScalarOp, AtenDetachOp,
-        AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenIndexPut_Op, AtenCumsumOp,
-        AtenLayerNormOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp,
-        AtenFloorOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp,
-        AtenDropoutOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp,
-        AtenAddIntOp, AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
+        AtenMaskedFill_ScalarOp, AtenCopy_Op, AtenCumsumOp, AtenLayerNormOp,
+        AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
+        AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
+        AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
+        AtenAbsOp, AtenReciprocalOp, AtenCeilOp>(op)) {
     return getLatticeElement(op->getResult(0)).join(*operands[0]);
   }
 
@@ -1272,7 +1272,6 @@ ChangeResult TypeAnalyzer::visitScalarToTensorConversionOp(OpTy op) {
   Value t = op.t();
   Value dtype = op.dtype();
   knowledge.hasSizes = true;
-  knowledge.sizes.resize(1, 1);
   fillInDTypeGivenDTypeAndDataType(knowledge, dtype, t.getType());
   return getLatticeElement(op.getResult()).join(knowledge);
 }
diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir
index 6c475099b..ca9e2b77a 100644
--- a/test/Dialect/Torch/refine-types.mlir
+++ b/test/Dialect/Torch/refine-types.mlir
@@ -593,8 +593,8 @@ builtin.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.ten
 // CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
 // CHECK: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<[1],f32>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[1],f32> to !torch.tensor
+// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.none, !torch.none, !torch.bool -> !torch.tensor<[],f32>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[],f32> to !torch.tensor
 // CHECK: return %[[CAST]] : !torch.tensor
 builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
@@ -611,8 +611,8 @@ builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
 // CHECK: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[CST11:.*]] = torch.constant.int 11
 // CHECK: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<[1],i1>
-// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[1],i1> to !torch.tensor
+// CHECK: %[[RET:.*]] = torch.aten.tensor.float %[[t]], %[[CST11]], %[[NONE]], %[[FALSE]] : !torch.float, !torch.int, !torch.none, !torch.bool -> !torch.tensor<[],i1>
+// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[RET]] : !torch.tensor<[],i1> to !torch.tensor
 // CHECK: return %[[CAST]] : !torch.tensor
 builtin.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torch.tensor {
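
Note (illustration only, not part of the patch above): in stock PyTorch, torch.tensor applied to a plain Python int or float scalar yields a rank-0 (zero-d) tensor, which is the behaviour the commit message, the RefineTypes change, and the updated CHECK lines reflect. A minimal sanity check:

    import torch

    t_int = torch.tensor(1)      # int scalar -> 0-d tensor (dtype int64)
    t_float = torch.tensor(1.0)  # float scalar -> 0-d tensor (dtype float32)

    print(t_int.shape, t_int.dim())      # torch.Size([]) 0
    print(t_float.shape, t_float.dim())  # torch.Size([]) 0

Because the shape is torch.Size([]) rather than torch.Size([1]), RefineTypes reports an empty size list and the lowering builds its init tensor with no dimensions.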