From b1a506624ce64a8f3ae2f878d5d35cd3f0dfae1d Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Mon, 8 Aug 2022 20:57:11 +0530 Subject: [PATCH] Add decomposition of `aten.masked.tensor` op. `aten.masked.tensor` op has been decomposed to `aten.masked.scalar` op. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 49 +++++++++++++++++++ .../TorchToLinalg/Uncategorized.cpp | 20 ++++++-- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 3 +- lib/Dialect/Torch/Transforms/ShapeLibrary.cpp | 4 ++ .../jit_ir/build_tools/shape_lib_gen.py | 3 ++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + .../test_suite/constant_alloc.py | 25 +++++++++- 7 files changed, 99 insertions(+), 6 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 20f1f9856..587f9cc41 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -1928,6 +1928,55 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [ }]; } +def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mask, + AnyTorchTensorType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mask, + AnyTorchTensorType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenClampOp : Torch_Op<"aten.clamp", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index a802f1d5d..1e5150d46 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -884,9 +884,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( threshold); return b.create(loc, predicate, constantZero, grad); } - if (auto maskedFill = dyn_cast(op)) { + if (auto maskedFillScalar = dyn_cast(op)) { AtenMaskedFillScalarOp::Adaptor adaptor(operands); - Type dtype = converter->convertType(maskedFill.getType()) + Type dtype = converter->convertType(maskedFillScalar.getType()) .cast() .getElementType(); @@ -896,6 +896,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, mask, fillValue, input); } + if (auto maskedFillTensor = dyn_cast(op)) { + AtenMaskedFillScalarOp::Adaptor adaptor(operands); + Type dtype = converter->convertType(maskedFillTensor.getType()) + .cast() + .getElementType(); + + Value input = payloadArgs[0]; + Value mask = payloadArgs[1]; + Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype); + return b.create(loc, mask, fillValue, input); + } if (auto triu = dyn_cast(op)) { // Check if the rank of the input tensor is valid. @@ -970,7 +981,7 @@ public: AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, - AtenLogicalOrOp, AtenTriuOp>(op)) + AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -1708,7 +1719,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, - AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>(); + AtenNeScalarOp, AtenMaskedFillScalarOp, AtenMaskedFillTensorOp, + AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index e72f65e92..c5e865790 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -658,7 +658,8 @@ void TypeAnalysis::visitOperation(Operation *op, AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp, - PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) { + PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp, AtenMaskedFillTensorOp>( + op)) { return incorporateKnowledge(op->getResult(0), operands[0]->getValue()); } diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index de7c14641..eaa848a0e 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -6214,6 +6214,10 @@ module { %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list return %0 : !torch.list } + func.func @"__torch_mlir_shape_fn.aten.masked_fill.Tensor"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list { + %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list + return %0 : !torch.list + } func.func @"__torch_mlir_shape_fn.aten.zero"(%arg0: !torch.list) -> !torch.list { return %arg0 : !torch.list } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index 203054d98..8f756cec2 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -777,6 +777,9 @@ def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Option def aten〇masked_fill〇Scalar(self: List[int], mask: List[int], value: float) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇masked_fill〇Tensor(self: List[int], mask: List[int], value: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇zero(self: List[int]) -> List[int]: return self diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index a09e5367f..973ee2a9f 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -279,6 +279,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", + "aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", "aten::clamp_min : (Tensor, Scalar) -> (Tensor)", "aten::clamp_max : (Tensor, Scalar) -> (Tensor)", diff --git a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py index 44282a879..fad09c4a8 100644 --- a/python/torch_mlir_e2e_test/test_suite/constant_alloc.py +++ b/python/torch_mlir_e2e_test/test_suite/constant_alloc.py @@ -342,7 +342,8 @@ class EmptyLikeMemoryFormatModule(torch.nn.Module): ([-1, -1, -1, -1], torch.float32, True), ]) def forward(self, a): - return torch.empty_like(a, memory_format=torch.preserve_format).fill_(0) + return torch.empty_like(a, + memory_format=torch.preserve_format).fill_(0) @register_test_case(module_factory=lambda: EmptyLikeMemoryFormatModule()) @@ -1421,3 +1422,25 @@ class MaskedFillScalarFloatValueModule(torch.nn.Module): def MaskedFillScalarFloatValueModule_basic(module, tu: TestUtils): module.forward(torch.randint(-10, 10, (2, 3)), torch.randint(0, 2, (2, 3)).to(dtype=torch.bool)) + + +class MaskedFillTensorFloatValueModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.bool, True), + ([], torch.float32, True), + ]) + def forward(self, x, mask, value): + return torch.ops.aten.masked_fill(x, mask, value=value) + + +@register_test_case(module_factory=lambda: MaskedFillTensorFloatValueModule()) +def MaskedFillTensorFloatValueModule_basic(module, tu: TestUtils): + module.forward(torch.randint(-10, 10, (2, 3)), + torch.randint(0, 2, (2, 3)).to(dtype=torch.bool), tu.rand())