[Torch] Add fold rule for AtenMaskedFillTensorOp to AtenMaskedFillScalarOp (#2543)

pull/2573/head snapshot-20231121.1029
Zhekun(Josh) Zhang 2023-11-21 13:26:17 +08:00 committed by GitHub
parent b26797c20b
commit d67afa9e95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 118 additions and 50 deletions

View File

@ -2102,55 +2102,6 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
}];
}
def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$mask,
AnyTorchTensorType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$mask,
Torch_NonValueTensorType:$value
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenClampOp : Torch_Op<"aten.clamp", [
AllowsTypeRefinement,
HasValueSemantics,
@ -3658,6 +3609,56 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
}];
}
def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$mask,
AnyTorchTensorType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$mask,
Torch_NonValueTensorType:$value
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -162,6 +162,42 @@ static Value getScalarIntValue(Value input, Location loc,
return nullptr;
}
static Value getScalarFloatValue(Value input, Location loc,
PatternRewriter &rewriter) {
auto inputType = input.getType();
if (inputType.isa<Torch::FloatType>()) {
return input;
}
auto inputTensorType = inputType.dyn_cast<BaseTensorType>();
if (!inputTensorType)
return nullptr;
Type inputDtype = inputTensorType.getOptionalDtype();
if (!inputDtype ||
(!inputDtype.isF16() && !inputDtype.isF32() && !inputDtype.isF64()))
return nullptr;
std::optional<unsigned> inputRank = getTensorRank(input);
if (!inputRank || *inputRank != 0)
return nullptr;
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
auto val = valueTensorLiteralOp.getValue()
.cast<DenseFPElementsAttr>()
.getSplatValue<FloatAttr>()
.getValueAsDouble();
return rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(val));
} else if (auto primNumToTensorScalarOp =
input.getDefiningOp<PrimNumToTensorScalarOp>()) {
return primNumToTensorScalarOp.getA();
} else if (auto tensorFloatOp = input.getDefiningOp<AtenTensorFloatOp>()) {
return tensorFloatOp.getT();
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// MethodOp
//===----------------------------------------------------------------------===//
@ -1589,6 +1625,27 @@ OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
return nullptr;
}
//===----------------------------------------------------------------------===//
// AtenMaskedFillTensorOp
//===----------------------------------------------------------------------===//
// Fold 0d fill tensor to scalar
void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](AtenMaskedFillTensorOp op, PatternRewriter &rewriter) {
auto scalarIntVal =
getScalarIntValue(op.getValue(), op->getLoc(), rewriter);
auto scalarFloatVal =
getScalarFloatValue(op.getValue(), op->getLoc(), rewriter);
if (!scalarIntVal && !scalarFloatVal)
return failure();
Value scalarVal = scalarIntVal ? scalarIntVal : scalarFloatVal;
rewriter.replaceOpWithNewOp<AtenMaskedFillScalarOp>(
op, op.getType(), op.getSelf(), op.getMask(), scalarVal);
return failure();
});
}
//===----------------------------------------------------------------------===//
// AtenSortIntOp
//===----------------------------------------------------------------------===//

View File

@ -300,7 +300,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
"aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
"aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)",
"aten::clamp_min : (Tensor, Scalar) -> (Tensor)",
@ -337,6 +336,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")

View File

@ -2136,3 +2136,13 @@ func.func @torch.aten.numel$canonicalize(%arg0: !torch.vtensor<[3,4],f32>) -> !t
%0 = torch.aten.numel %arg0 : !torch.vtensor<[3,4],f32> -> !torch.int
return %0 : !torch.int
}
// CHECK-LABEL: func.func @torch.aten.masked_fill.Tensor$canonicalize
// CHECK-NEXT: torch.constant.float -1.000000e+09
// CHECK-NEXT: torch.aten.masked_fill.Scalar
// CHECK-NEXT: return
func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.vtensor.literal(dense<-1.000000e+09> : tensor<f32>) : !torch.vtensor<[],f32>
%1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}