From d67afa9e951128b6d3973044ea45d5daa127de4b Mon Sep 17 00:00:00 2001 From: "Zhekun(Josh) Zhang" <32320144+zhekunz2@users.noreply.github.com> Date: Tue, 21 Nov 2023 13:26:17 +0800 Subject: [PATCH] [Torch] Add fold rule for AtenMaskedFillTensorOp to AtenMaskedFillScalarOp (#2543) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 99 ++++++++++--------- lib/Dialect/Torch/IR/TorchOps.cpp | 57 +++++++++++ .../build_tools/torch_ods_gen.py | 2 +- test/Dialect/Torch/canonicalize.mlir | 10 ++ 4 files changed, 118 insertions(+), 50 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 0c3efd6ce..4ecac580c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2102,55 +2102,6 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [ }]; } -def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$mask, - AnyTorchTensorType:$value - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - -def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - 
Torch_NonValueTensorType:$mask, - Torch_NonValueTensorType:$value - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - def Torch_AtenClampOp : Torch_Op<"aten.clamp", [ AllowsTypeRefinement, HasValueSemantics, @@ -3658,6 +3609,56 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ }]; } +def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$mask, + AnyTorchTensorType:$value + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; + let hasCanonicalizer = 1; +} + +def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$mask, + Torch_NonValueTensorType:$value + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState 
&result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index c7dc571b6..111f1a7b6 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -162,6 +162,42 @@ static Value getScalarIntValue(Value input, Location loc, return nullptr; } +static Value getScalarFloatValue(Value input, Location loc, + PatternRewriter &rewriter) { + auto inputType = input.getType(); + if (inputType.isa<Torch::FloatType>()) { + return input; + } + + auto inputTensorType = inputType.dyn_cast<BaseTensorType>(); + if (!inputTensorType) + return nullptr; + + Type inputDtype = inputTensorType.getOptionalDtype(); + if (!inputDtype || + (!inputDtype.isF16() && !inputDtype.isF32() && !inputDtype.isF64())) + return nullptr; + + std::optional<unsigned> inputRank = getTensorRank(input); + if (!inputRank || *inputRank != 0) + return nullptr; + + if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) { + auto val = valueTensorLiteralOp.getValue() + .cast<DenseFPElementsAttr>() + .getSplatValue<FloatAttr>() + .getValueAsDouble(); + return rewriter.create<Torch::ConstantFloatOp>( + loc, rewriter.getF64FloatAttr(val)); + } else if (auto primNumToTensorScalarOp = + input.getDefiningOp<PrimNumToTensorScalarOp>()) { + return primNumToTensorScalarOp.getA(); + } else if (auto tensorFloatOp = input.getDefiningOp<AtenTensorFloatOp>()) { + return tensorFloatOp.getT(); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // MethodOp //===----------------------------------------------------------------------===// @@ -1589,6 +1625,27 @@ OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenMaskedFillTensorOp
+//===----------------------------------------------------------------------===// + +// Fold 0d fill tensor to scalar +void AtenMaskedFillTensorOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenMaskedFillTensorOp op, PatternRewriter &rewriter) { + auto scalarIntVal = + getScalarIntValue(op.getValue(), op->getLoc(), rewriter); + auto scalarFloatVal = + getScalarFloatValue(op.getValue(), op->getLoc(), rewriter); + if (!scalarIntVal && !scalarFloatVal) + return failure(); + Value scalarVal = scalarIntVal ? scalarIntVal : scalarFloatVal; + rewriter.replaceOpWithNewOp<AtenMaskedFillScalarOp>( + op, op.getType(), op.getSelf(), op.getMask(), scalarVal); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenSortIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ff78d463a..557a2fadb 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -300,7 +300,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", - "aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", "aten::clamp.Tensor : (Tensor, Tensor?, Tensor?)
-> (Tensor)", "aten::clamp_min : (Tensor, Scalar) -> (Tensor)", @@ -337,6 +336,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 82535062a..5dfd8daa9 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2136,3 +2136,13 @@ func.func @torch.aten.numel$canonicalize(%arg0: !torch.vtensor<[3,4],f32>) -> !t %0 = torch.aten.numel %arg0 : !torch.vtensor<[3,4],f32> -> !torch.int return %0 : !torch.int } + +// CHECK-LABEL: func.func @torch.aten.masked_fill.Tensor$canonicalize +// CHECK-NEXT: torch.constant.float -1.000000e+09 +// CHECK-NEXT: torch.aten.masked_fill.Scalar +// CHECK-NEXT: return +func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],f32> { + %0 = torch.vtensor.literal(dense<-1.000000e+09> : tensor<f32>) : !torch.vtensor<[],f32> + %1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +}