From 9be89975367176615e8b1d1799777757864ecd78 Mon Sep 17 00:00:00 2001
From: Yan Xu
Date: Wed, 17 Aug 2022 13:48:10 +0800
Subject: [PATCH] Revert "add native_dropout and related ops pattern (#1211)"
 (#1230)

This reverts commit c9357950864d84ab9a2adf10b2aaa15fec6196e9.
---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |   3 +-
 lib/Conversion/TorchToMhlo/Basic.cpp          | 100 ------------------
 .../Torch/Transforms/DecomposeComplexOps.cpp  |  43 --------
 .../TorchConversion/Transforms/Passes.cpp     |   2 -
 test/Conversion/TorchToMhlo/dropout.mlir      |  47 --------
 test/Conversion/TorchToMhlo/view_like.mlir    |  12 +--
 6 files changed, 3 insertions(+), 204 deletions(-)
 delete mode 100644 test/Conversion/TorchToMhlo/dropout.mlir

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 0417048ff..6b9ead8d5 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5319,10 +5319,9 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
 }
 
 def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
-    NoSideEffect,
     AllowsTypeRefinement,
     HasValueSemantics,
-    ReadOnly,
+    ReadOnly
   ]> {
   let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
   let arguments = (ins
diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp
index 71ca9de23..2e65d3e09 100644
--- a/lib/Conversion/TorchToMhlo/Basic.cpp
+++ b/lib/Conversion/TorchToMhlo/Basic.cpp
@@ -71,25 +71,6 @@ public:
 };
 } // namespace
 
-// ConvertAtenUnaryConvertOp legalize genearl unary ops into Mhlo ConverOp
-namespace {
-template <typename AtenOpT>
-class ConvertAtenUnaryConvertOp: public OpConversionPattern<AtenOpT> {
-public:
-  using OpConversionPattern<AtenOpT>::OpConversionPattern;
-  using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                                ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
-        op,
-        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-            op.getType()),
-        adaptor.self());
-    return success();
-  }
-};
-} // namespace
-
 // aten.ones & aten.zeros
 // Ref: Error checking based on the Torch to TOSA lowering
 namespace {
@@ -329,9 +310,6 @@ public:
                    std::is_same<AtenOpT, AtenGtScalarOp>()) {
       compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
           op->getContext(), mhlo::ComparisonDirection::GT);
-    } else if (std::is_same<AtenOpT, AtenGeScalarOp>()) {
-      compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
-          op->getContext(), mhlo::ComparisonDirection::GE);
     } else if (std::is_same<AtenOpT, AtenLtTensorOp>() ||
                std::is_same<AtenOpT, AtenLtScalarOp>()) {
       compareDirectionAttr = mhlo::ComparisonDirectionAttr::get(
@@ -1005,72 +983,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 }
 } // namespace
 
-// AtenSizeIntOp
-namespace {
-template <>
-LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
-    AtenSizeIntOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter) const {
-  // Not a tensor type.
-  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
-  if (!selfType)
-    return op.emitError("Only tensor types are currently supported");
-  auto dim = rewriter.create<arith::IndexCastOp>(
-      op.getLoc(), rewriter.getIndexType(), adaptor.dim());
-  auto dimSize = rewriter.create<tensor::DimOp>(
-      op.getLoc(), rewriter.getIndexType(), adaptor.self(), dim);
-
-  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
-      op, getTypeConverter()->convertType(op.getType()), dimSize);
-
-  return success();
-}
-} // namespace
-
-// ValsemVariantAtenUniformOp
-namespace {
-template <>
-LogicalResult ConvertAtenOp<ValsemVariantAtenUniformOp>::matchAndRewrite(
-    ValsemVariantAtenUniformOp op,
-    OpAdaptor adaptor,
-    ConversionPatternRewriter& rewriter) const {
-  auto inputTy = adaptor.self().getType().template cast<RankedTensorType>();
-  auto loc = op.getLoc();
-  if (!inputTy) {
-    op.emitError("input should be ranked tensor type.");
-  }
-  auto definingOp = op.self().getDefiningOp();
-  auto shape = definingOp->getOperand(0);
-  SmallVector<Value, 4> dimSizes;
-  getListConstructElements(shape, dimSizes);
-  std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value& dSize) {
-    dSize = rewriter.create<ToI64Op>(loc, dSize).getResult();
-    return dSize;
-  });
-
-  auto mhloShape =
-      rewriter.create<tensor::FromElementsOp>(op.getLoc(), dimSizes);
-
-  double fromDoubleValue, toDoubleValue;
-  if (!matchPattern(op.from(), m_TorchConstantFloat(&fromDoubleValue))) {
-    op.emitError("operand #1 should be scalar");
-  }
-  if (!matchPattern(op.to(), m_TorchConstantFloat(&toDoubleValue))) {
-    op.emitError("operand #2 should be scalar");
-  }
-  Value fromTensor = rewriter.create<mhlo::ConstantOp>(
-      op.getLoc(),
-      rewriter.getFloatAttr(inputTy.getElementType(), fromDoubleValue));
-  Value toTensor = rewriter.create<mhlo::ConstantOp>(
-      op.getLoc(),
-      rewriter.getFloatAttr(inputTy.getElementType(), toDoubleValue));
-
-  rewriter.replaceOpWithNewOp<mhlo::RngOp>(
-      op, inputTy, fromTensor, toTensor, mhloShape, mhlo::RngDistribution::UNIFORM);
-  return success();
-}
-}
 void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -1096,15 +1008,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_UNARY_FPONLY_PATTERN(AtenNegOp, mhlo::NegOp);
 #undef INSERT_UNARY_FPONLY_PATTERN
-#define INSERT_UNARY_CONVERT_PATTERN(AtenOp) \
-  target.addIllegalOp<AtenOp>(); \
-  patterns.add<ConvertAtenUnaryConvertOp<AtenOp>>(typeConverter, \
-                                                  context);
-  INSERT_UNARY_CONVERT_PATTERN(AtenContiguousOp);
-  INSERT_UNARY_CONVERT_PATTERN(AtenToDtypeOp);
-  INSERT_UNARY_CONVERT_PATTERN(AtenTypeAsOp);
-#undef INSERT_UNARY_CONVERT_PATTERN
-
 #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \
   target.addIllegalOp<AtenOp>(); \
   patterns.add<ConvertAtenConstPatternOp<AtenOp, fillVal>>(typeConverter, \
                                                            context);
@@ -1139,7 +1042,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 
   INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp);
-  INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp);
   INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp);
@@ -1165,7 +1067,5 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
 
   INSERT_ATENOP_PATTERN(AtenBatchNormOp);
   INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
-  INSERT_ATENOP_PATTERN(AtenSizeIntOp);
-  INSERT_ATENOP_PATTERN(ValsemVariantAtenUniformOp);
 #undef INSERT_ATENOP_PATTERN
 }
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 60cf7c5e7..6d82d95f1 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1155,47 +1155,6 @@ public:
 };
 } // namespace
 
-namespace {
-class DecomposeAtenNativeDropoutOp : public OpRewritePattern<AtenNativeDropoutOp> {
-public:
-  using OpRewritePattern<AtenNativeDropoutOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenNativeDropoutOp op,
-                                PatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    Value input = op.input();
-    Value prob = op.p();
-    bool train = false;
-    if (!matchPattern(op.train(), m_TorchConstantBool(&train)))
-      return rewriter.notifyMatchFailure(op, "train must be a boolean constant");
-    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
-    if (!train) {
-      // TODO(yancey.yx): supports inference mode
-      return op.emitError(
-          "native_dropout does not support argument train is false");
-    }
-    if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
-      return rewriter.notifyMatchFailure(
-          op, "only support floating type input for training mode");
-    Value noneVal = rewriter.create<ConstantNoneOp>(loc);
-    Value floatOne =
-        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
-    Value oneMinusP = rewriter.create<AtenSubFloatOp>(loc, floatOne, prob);
-    Value boolMask = rewriter.create<ValsemVariantAtenBernoulliFloatOp>(
-        loc, inputType, input, oneMinusP, /*generator=*/noneVal);
-    Value maskedInput =
-        rewriter.create<AtenMulTensorOp>(loc, inputType, boolMask, input);
-    Value output =
-        rewriter.create<AtenMulScalarOp>(loc, inputType, maskedInput, oneMinusP);
-    Value one =
-        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
-    boolMask = rewriter.create<AtenGeScalarOp>(
-        loc, op.getResult(1).getType(), boolMask, one);
-    rewriter.replaceOp(op, {output, boolMask});
-    return success();
-  }
-};
-} // namespace
-
 // Decompose aten.var into: aten.var.dim op.
 namespace {
 class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
@@ -2635,8 +2594,6 @@ class DecomposeComplexOpsPass
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
-    patterns.add<DecomposeAtenNativeDropoutOp>(context);
-    target.addIllegalOp<AtenNativeDropoutOp>();
     target.addIllegalOp();
     target.addIllegalOp();
     patterns.add(context);
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index 1e6285a7e..c74cc742a 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -139,8 +139,6 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline(
       TorchConversion::createVerifyInvariantsBeforeBackendLoweringPass());
 
   pm.addNestedPass<func::FuncOp>(createConvertTorchToMhloPass());
-  pm.addNestedPass<func::FuncOp>(createConvertTorchToSCFPass());
-  pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
 
   if (options.optimize) {
     // Clean up any non-canonical code introduced above..
diff --git a/test/Conversion/TorchToMhlo/dropout.mlir b/test/Conversion/TorchToMhlo/dropout.mlir
deleted file mode 100644
index e595b7ecc..000000000
--- a/test/Conversion/TorchToMhlo/dropout.mlir
+++ /dev/null
@@ -1,47 +0,0 @@
-// RUN: torch-mlir-opt < %s --torch-function-to-torch-backend-pipeline --torch-backend-to-mhlo-backend-pipeline -split-input-file -verify-diagnostics | FileCheck %s
-
-// CHECK-LABEL: func.func @torch.aten.native_dropout.train(
-// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: f64) -> (tensor<?x?xf32>, tensor<?x?xi1>) {
-// CHECK: %[[T0:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
-// CHECK: %[[CST_0:.*]] = arith.constant 1 : index
-// CHECK: %[[CST_1:.*]] = arith.constant 0 : index
-// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
-// CHECK: %[[T2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f64>
-// CHECK: %[[CST_2:.*]] = arith.constant 1.000000e+00 : f64
-// CHECK: %[[CST_3:.*]] = arith.subf %[[CST_2]], %[[ARG1]] : f64
-// CHECK: %[[T3:.*]] = tensor.from_elements %[[CST_3]] : tensor<1xf64>
-// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf64>) -> tensor<f64>
-// CHECK: %[[T5:.*]] = mhlo.convert(%[[ARG0]]) : (tensor<?x?xf32>) -> tensor<?x?xf64>
-// CHECK: %[[DIM_0:.*]] = tensor.dim %[[T5]], %[[CST_1]] : tensor<?x?xf64>
-// CHECK: %[[CST_I64_0:.*]] = arith.index_cast %[[DIM_0]] : index to i64
-// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T5]], %[[CST_0]] : tensor<?x?xf64>
-// CHECK: %[[CST_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64
-// CHECK: %[[T6:.*]] = tensor.from_elements %[[CST_I64_0]], %[[CST_I64_1]] : tensor<2xi64>
-// CHECK: %[[T7:.*]] = "mhlo.rng"(%[[T2]], %[[T1]], %[[T6]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>} : (tensor<f64>, tensor<f64>, tensor<2xi64>) -> tensor<?x?xf64>
-// CHECK: %[[T8:.*]] = shape.shape_of %[[T7]] : tensor<?x?xf64> -> tensor<2xindex>
-// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T4]], %[[T8]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f64>, tensor<2xindex>) -> tensor<?x?xf64>
-// CHECK: %[[T10:.*]] = mhlo.compare LT, %[[T7]], %[[T9]], FLOAT : (tensor<?x?xf64>, tensor<?x?xf64>) -> tensor<?x?xi1>
-// CHECK: %[[T11:.*]] = mhlo.convert(%[[T10]]) : (tensor<?x?xi1>) -> tensor<?x?xf32>
-// CHECK: %[[T12:.*]] = shape.shape_of %[[T11]] : tensor<?x?xf32> -> tensor<2xindex>
-// CHECK: %[[T13:.*]] = shape.shape_of %[[ARG0]] : tensor<?x?xf32> -> tensor<2xindex>
-// CHECK: %[[T14:.*]] = shape.cstr_broadcastable %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex>
-// CHECK: %[[T15:.*]] = shape.assuming %[[T14]] -> (tensor<?x?xf32>) {
-// CHECK: %[[T16:.*]] = shape.broadcast %[[T12]], %[[T13]] : tensor<2xindex>, tensor<2xindex> -> tensor<2xindex>
-// CHECK: %[[T17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T11]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-// CHECK: %[[T18:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[T16]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
-// CHECK: %[[T19:.*]] = mhlo.multiply %[[T17]], %[[T18]] : tensor<?x?xf32>
-// CHECK: shape.assuming_yield %[[T19]] : tensor<?x?xf32>
-// CHECK: }
-// CHECK: %[[T20:.*]] = mhlo.convert(%[[T3]]) : (tensor<1xf64>) -> tensor<1xf32>
-// CHECK: %[[T21:.*]] = mhlo.reshape %[[T20]] : (tensor<1xf32>) -> tensor<f32>
-// CHECK: %[[T22:.*]] = shape.shape_of %[[T15]] : tensor<?x?xf32> -> tensor<2xindex>
-// CHECK: %[[T23:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T21]], %[[T22]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
-// CHECK: %[[T24:.*]] = mhlo.multiply %[[T15]], %[[T23]] : tensor<?x?xf32>
-// CHECK: %[[T25:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T12]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
-// CHECK: %[[T26:.*]] = mhlo.compare GE, %[[T11]], %[[T25]], FLOAT : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
-// CHECK: return %[[T24]], %[[T26]] : tensor<?x?xf32>, tensor<?x?xi1>
-func.func @torch.aten.native_dropout.train(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.float) -> (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>) {
-  %bool_true = torch.constant.bool true
-  %result0, %result1 = torch.aten.native_dropout %arg0, %arg1, %bool_true: !torch.vtensor<[?,?],f32>, !torch.float, !torch.bool -> !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
-  return %result0, %result1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>
-}
diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir
index 37d38d9ae..4783e9940 100644
--- a/test/Conversion/TorchToMhlo/view_like.mlir
+++ b/test/Conversion/TorchToMhlo/view_like.mlir
@@ -360,17 +360,9 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3,?,?],f32> -> tensor<2x3x?x?xf32>
 // CHECK: %[[INTneg1:.*]] = torch.constant.int -1
 // CHECK: %[[INT1:.*]] = torch.constant.int 1
-// CHECK: %[[C1_I64:.*]] = torch_c.to_i64 %[[INT1]]
 // CHECK: %[[INT0:.*]] = torch.constant.int 0
-// CHECK: %[[C2_I64:.*]] = torch_c.to_i64 %[[INT0]]
-// CHECK: %[[INDEX_1:.*]] = arith.index_cast %[[C2_I64]] : i64 to index
-// CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[INDEX_1]] : tensor<2x3x?x?xf32>
-// CHECK: %[[DIM_I64_1:.*]] = arith.index_cast %[[DIM_1]] : index to i64
-// CHECK: %[[T1:.*]] = torch_c.from_i64 %[[DIM_I64_1]]
-// CHECK: %[[INDEX_2:.*]] = arith.index_cast %[[C1_I64]] : i64 to index
-// CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[INDEX_2]] : tensor<2x3x?x?xf32>
-// CHECK: %[[DIM_I64_2:.*]] = arith.index_cast %[[DIM_2]] : index to i64
-// CHECK: %[[T2:.*]] = torch_c.from_i64 %[[DIM_I64_2]]
+// CHECK: %[[T1:.*]] = torch.aten.size.int %[[ARG0]], %[[INT0]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
+// CHECK: %[[T2:.*]] = torch.aten.size.int %[[ARG0]], %[[INT1]] : !torch.vtensor<[2,3,?,?],f32>, !torch.int -> !torch.int
 // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[T1]], %[[T2]], %[[INTneg1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[T4:.*]] = torch_c.to_i64 %[[T1]]
 // CHECK: %[[T5:.*]] = torch_c.to_i64 %[[T2]]