mirror of https://github.com/llvm/torch-mlir
[Torch] Add fold rule for AtenMaskedFillTensorOp to AtenMaskedFillScalarOp (#2543)
parent
b26797c20b
commit
d67afa9e95
|
@ -2102,55 +2102,6 @@ def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [
|
|||
}];
|
||||
}
|
||||
|
||||
// ODS definition for the value-semantic `aten::masked_fill.Tensor` op.
// NOTE(review): in this diff rendering the same op also appears later in the
// file with `hasCanonicalizer = 1`; this copy is the pre-move (removed)
// version. This file is generated — edit the emitter, not this file.
def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
  AllowsTypeRefinement,
  HasValueSemantics,
  ReadOnly
]> {
  let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$mask,
    AnyTorchTensorType:$value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Parsing/printing delegate to the shared default Torch-op format
  // (3 operands, 1 result).
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
|
||||
|
||||
// ODS definition for the in-place (trailing-underscore) variant
// `aten::masked_fill_.Tensor`; operates on non-value tensors.
// NOTE(review): pre-move (removed) copy in this diff rendering — the op is
// re-emitted later in the file. This file is generated; edit the emitter.
def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
  IsTrailingUnderscoreInplaceVariant,
  AllowsTypeRefinement
]> {
  let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self,
    Torch_NonValueTensorType:$mask,
    Torch_NonValueTensorType:$value
  );
  let results = (outs
    Torch_NonValueTensorType:$result
  );
  // Parsing/printing delegate to the shared default Torch-op format
  // (3 operands, 1 result).
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
|
||||
|
||||
def Torch_AtenClampOp : Torch_Op<"aten.clamp", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -3658,6 +3609,56 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [
|
|||
}];
|
||||
}
|
||||
|
||||
// ODS definition for `aten::masked_fill.Tensor` at its new position in the
// generated file, now declaring a canonicalizer (`hasCanonicalizer = 1`)
// that folds a 0-d `value` tensor into `aten.masked_fill.Scalar`.
// This file is generated — edit the emitter, not this file.
def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [
  AllowsTypeRefinement,
  HasValueSemantics,
  ReadOnly
]> {
  let summary = "Generated op for `aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchTensorType:$mask,
    AnyTorchTensorType:$value
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  // Parsing/printing delegate to the shared default Torch-op format
  // (3 operands, 1 result).
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenMaskedFillTensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenMaskedFillTensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
  // Declares AtenMaskedFillTensorOp::getCanonicalizationPatterns, defined
  // in the dialect's C++ sources.
  let hasCanonicalizer = 1;
}
|
||||
|
||||
// ODS definition for the in-place variant `aten::masked_fill_.Tensor` at its
// new position in the generated file (no canonicalizer — only the
// value-semantic variant is folded). This file is generated; edit the emitter.
def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
  IsTrailingUnderscoreInplaceVariant,
  AllowsTypeRefinement
]> {
  let summary = "Generated op for `aten::masked_fill_.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self,
    Torch_NonValueTensorType:$mask,
    Torch_NonValueTensorType:$value
  );
  let results = (outs
    Torch_NonValueTensorType:$result
  );
  // Parsing/printing delegate to the shared default Torch-op format
  // (3 operands, 1 result).
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenMaskedFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenMaskedFill_TensorOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
|
||||
|
||||
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -162,6 +162,42 @@ static Value getScalarIntValue(Value input, Location loc,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
/// Attempts to recover a `!torch.float` scalar from `input`.
///
/// Succeeds when `input` is already a scalar float, or when it is a 0-d
/// tensor with a known f16/f32/f64 dtype produced by an op we can peek
/// through (a splat literal, `prim.NumToTensor.Scalar`, or
/// `aten.tensor.float`). Returns nullptr when no scalar can be derived.
/// May materialize a `torch.constant.float` via `rewriter` in the literal
/// case.
static Value getScalarFloatValue(Value input, Location loc,
                                 PatternRewriter &rewriter) {
  Type type = input.getType();
  // Already a scalar float — nothing to extract.
  if (type.isa<Torch::FloatType>())
    return input;

  // Otherwise it must be a 0-d tensor with a recognized float dtype.
  auto tensorTy = type.dyn_cast<BaseTensorType>();
  if (!tensorTy)
    return nullptr;

  Type dtype = tensorTy.getOptionalDtype();
  bool hasFloatDtype =
      dtype && (dtype.isF16() || dtype.isF32() || dtype.isF64());
  if (!hasFloatDtype)
    return nullptr;

  std::optional<unsigned> rank = getTensorRank(input);
  if (!rank || *rank != 0)
    return nullptr;

  // Peek through the ops that commonly produce 0-d float tensors.
  if (auto literal = input.getDefiningOp<ValueTensorLiteralOp>()) {
    // A 0-d literal is necessarily a splat; read it back as a double and
    // rebuild it as a scalar float constant.
    double splat = literal.getValue()
                       .cast<DenseFPElementsAttr>()
                       .getSplatValue<FloatAttr>()
                       .getValueAsDouble();
    return rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(splat));
  }
  if (auto numToTensor = input.getDefiningOp<PrimNumToTensorScalarOp>())
    return numToTensor.getA();
  if (auto tensorFloat = input.getDefiningOp<AtenTensorFloatOp>())
    return tensorFloat.getT();

  return nullptr;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MethodOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1589,6 +1625,27 @@ OpFoldResult AtenIntBoolOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenMaskedFillTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Fold 0d fill tensor to scalar
|
||||
void AtenMaskedFillTensorOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](AtenMaskedFillTensorOp op, PatternRewriter &rewriter) {
|
||||
auto scalarIntVal =
|
||||
getScalarIntValue(op.getValue(), op->getLoc(), rewriter);
|
||||
auto scalarFloatVal =
|
||||
getScalarFloatValue(op.getValue(), op->getLoc(), rewriter);
|
||||
if (!scalarIntVal && !scalarFloatVal)
|
||||
return failure();
|
||||
Value scalarVal = scalarIntVal ? scalarIntVal : scalarFloatVal;
|
||||
rewriter.replaceOpWithNewOp<AtenMaskedFillScalarOp>(
|
||||
op, op.getType(), op.getSelf(), op.getMask(), scalarVal);
|
||||
return failure();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenSortIntOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -300,7 +300,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::le.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||
"aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
|
||||
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
|
||||
"aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)",
|
||||
"aten::clamp_min : (Tensor, Scalar) -> (Tensor)",
|
||||
|
@ -337,6 +336,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
|
||||
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
|
|
|
@ -2136,3 +2136,13 @@ func.func @torch.aten.numel$canonicalize(%arg0: !torch.vtensor<[3,4],f32>) -> !t
|
|||
%0 = torch.aten.numel %arg0 : !torch.vtensor<[3,4],f32> -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
||||
// Verifies the AtenMaskedFillTensorOp canonicalizer: a masked_fill.Tensor
// whose fill value is a 0-d splat literal is rewritten to
// masked_fill.Scalar, with the splat rematerialized as a constant float.
// CHECK-LABEL: func.func @torch.aten.masked_fill.Tensor$canonicalize
// CHECK-NEXT: torch.constant.float -1.000000e+09
// CHECK-NEXT: torch.aten.masked_fill.Scalar
// CHECK-NEXT: return
func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.vtensor.literal(dense<-1.000000e+09> : tensor<f32>) : !torch.vtensor<[],f32>
  %1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
  return %1 : !torch.vtensor<[?,?],f32>
}
|
||||
|
|
Loading…
Reference in New Issue