From b2cefc0b642a316946b6368a7f1e444e32b90dae Mon Sep 17 00:00:00 2001
From: Chi_Liu
Date: Wed, 21 Dec 2022 08:56:07 -0800
Subject: [PATCH] [TOSA] Add aten.masked_fill.Tensor/Scalar support (#1735)

---
 e2e_testing/xfail_sets.py                  |  3 +
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 68 +++++++++++++++++++
 .../test_suite/constant_alloc.py           | 43 ++++++++++++
 test/Conversion/TorchToTosa/basic.mlir     | 36 ++++++++++
 4 files changed, 150 insertions(+)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index eb8850a36..351fad260 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -564,6 +564,8 @@ TOSA_PASS_SET = {
     "_LogSoftmaxModuleStable_basic",
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseUnsqueezeBroadcastModule_basic",
+    "MaskedFillScalarIntValueStaticModule_basic",
+    "MaskedFillTensorIntValueStaticModule_basic",
     "ElementwiseAddScalarInt64Module_basic",
     "TensorLiteralModule_basic",
     "TensorOpaqueLiteralModule_basic",
@@ -727,6 +729,7 @@ LTC_XFAIL_SET = {
     "ElementwiseRemainderScalarModule_Bool_basic",
     "AtenIntTensorByteDtypeModule_basic",
     "AtenIntTensorCharDtypeModule_basic",
+    "MaskedFillTensorIntValueStaticModule_basic",
     "Fill_TensorFloat32WithFloat32_basic",
     "Fill_TensorFloat32WithFloat64_basic",
     "Fill_TensorFloat32WithInt64_basic",
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 41a2b9d80..2f18a8011 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3747,6 +3747,67 @@ public:
   }
 };
 
+template <typename AtenOpT>
+class ConvertAtenMaskedFillOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto outType = OpConversionPattern<AtenOpT>::getTypeConverter()
+                       ->convertType(op.getType())
+                       .template dyn_cast<TensorType>();
+
+    if (!outType || !outType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "Only Tensor types with static shapes are currently supported");
+
+    Type outElemTy = outType.getElementType();
+    if (!outElemTy.isIntOrFloat()) {
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point or integer datatype legalization supported");
+    }
+
+    // Not a tensor type.
+    auto selfType = adaptor.getSelf().getType().template dyn_cast<TensorType>();
+    if (!selfType || !outType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op,
+          "Only tensor types with static shapes input are currently supported");
+
+    auto maskType = adaptor.getMask().getType().template dyn_cast<TensorType>();
+    if (!maskType)
+      return rewriter.notifyMatchFailure(
+          op, "Only tensor types mask are currently supported");
+
+    Value rhs = adaptor.getValue();
+    auto rhsType = rhs.getType().template dyn_cast<TensorType>();
+    Value rhsAsTensor;
+    if (!rhsType) { // scalar
+      if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(),
+                                         rhsAsTensor, rhs.getType(), {})))
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA operation");
+    } else { // tensor
+      rhsType = rhs.getType().dyn_cast<TensorType>();
+    }
+
+    auto rhsTensor = rhsType ? rhs : rhsAsTensor;
+    auto rhsTensorType = rhsTensor.getType().template dyn_cast<TensorType>();
+    if (rhsTensorType.getElementType() != outElemTy)
+      rhsTensor = rewriter.create<tosa::CastOp>(
+          op.getLoc(),
+          RankedTensorType::get(rhsTensorType.getShape(), outElemTy),
+          rhsTensor);
+
+    rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, adaptor.getMask(),
+                                                rhsTensor, adaptor.getSelf());
+    return success();
+  }
+};
+
 // Legalizes the torch.clone op.
 template <typename AtenOpT>
 class ConvertAtenCloneOp : public OpConversionPattern<AtenOpT> {
@@ -3947,6 +4008,13 @@ public:
   INSERT_FILL_SCALAR_PATTERN(AtenFill_ScalarOp);
 #undef INSERT_FILL_SCALAR_PATTERN
 
+#define INSERT_MASKED_FILL_PATTERN(AtenOp)                                     \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMaskedFillOp<AtenOp>>(typeConverter, context);
+  INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp);
+  INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
+#undef INSERT_MASKED_FILL_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
index 1572bf73b..43157750b 100644
--- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
+++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -1402,3 +1402,46 @@ class MaskedFillTensorFloatValueModule(torch.nn.Module):
 def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils):
     module.forward(tu.randint(2, 3, low=-10, high=10),
                    tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.rand())
+
+
+class MaskedFillScalarIntValueStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3], torch.int64, True),
+        ([2, 3], torch.bool, True),
+    ])
+    def forward(self, x, mask):
+        return torch.ops.aten.masked_fill(x, mask, value=5)
+
+
+@register_test_case(module_factory=lambda: MaskedFillScalarIntValueStaticModule())
+def MaskedFillScalarIntValueStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 3),
+                   tu.randint(2, 3, high=2).to(dtype=torch.bool))
+
+
+class MaskedFillTensorIntValueStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([2, 3], torch.int64, True),
+        ([2, 3], torch.bool, True),
+        ([], torch.int64, True),
+    ])
+    def forward(self, x, mask, value):
+        return torch.ops.aten.masked_fill(x, mask, value=value)
+
+
+@register_test_case(module_factory=lambda: MaskedFillTensorIntValueStaticModule())
+def MaskedFillTensorIntValueStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 3),
+                   tu.randint(2, 3, high=2).to(dtype=torch.bool), tu.randint())
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index aba391834..5f035e5f3 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -961,6 +961,42 @@ func.func @torch.aten.clamp(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch
   return %0 : !torch.vtensor<[1,1,128,128],si64>
 }
 
+// -----
+// CHECK-LABEL: func.func @torch.aten.masked_fill.Scalar(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1>
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<0> : tensor<i64>} : () -> tensor<i64>
+// CHECK: %[[VAL_6:.*]] = "tosa.cast"(%[[VAL_5]]) : (tensor<i64>) -> tensor<f32>
+// CHECK: %[[VAL_7:.*]] = "tosa.select"(%[[VAL_3]], %[[VAL_6]], %[[VAL_2]]) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
+// CHECK: return %[[VAL_8]] : !torch.vtensor<[1,12,128,128],f32>
+// CHECK: }
+func.func @torch.aten.masked_fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>) -> !torch.vtensor<[1,12,128,128],f32> {
+  %int0 = torch.constant.int 0
+  %0 = torch.aten.masked_fill.Scalar %arg0, %arg1, %int0 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.int -> !torch.vtensor<[1,12,128,128],f32>
+  return %0 : !torch.vtensor<[1,12,128,128],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.masked_fill.Tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,12,128,128],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,128,128],i1>,
+// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,12,128,128],f32> -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,1,128,128],i1> -> tensor<1x1x128x128xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[],f32> -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.select"(%[[VAL_4]], %[[VAL_5]], %[[VAL_3]]) : (tensor<1x1x128x128xi1>, tensor<f32>, tensor<1x12x128x128xf32>) -> tensor<1x12x128x128xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32>
+// CHECK: }
+func.func @torch.aten.masked_fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1,1,128,128],i1>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[1,12,128,128],f32> {
+  %0 = torch.aten.masked_fill.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1,1,128,128],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,128,128],f32>
+  return %0 : !torch.vtensor<[1,12,128,128],f32>
+}
+
 // -----
 // CHECK-LABEL: func.func @torch.aten.where.self(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,5,5],i1>,
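
Example usage (illustrative only, not part of the patch): the sketch below is one way to exercise the new lowering outside the e2e suite. It assumes the torch_mlir.compile() Python entry point with output_type="tosa" available in this revision of the repository; the module and variable names are hypothetical.

# Illustrative sketch only -- not part of the patch.
import torch
import torch_mlir  # assumes the in-tree torch_mlir Python package


class MaskedFillExample(torch.nn.Module):
    def forward(self, x, mask):
        # Scalar variant: positions where mask is True are overwritten with 5.
        return torch.ops.aten.masked_fill(x, mask, value=5)


# The new TOSA pattern only supports statically shaped tensors, so fixed-size
# example inputs are used here.
x = torch.randint(-10, 10, (2, 3))
mask = torch.randint(0, 2, (2, 3)).to(torch.bool)

module = torch_mlir.compile(MaskedFillExample(), [x, mask], output_type="tosa")
# The printed module should contain a tosa.select whose true operand is the
# (possibly tosa.cast'ed) fill value and whose false operand is the input tensor.
print(module)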