[torch-dialect] fix aten.type_as op's folder (#2283)

[torch-dialect] fix torch.type_as op's folder by decomposing it to prim.dtype + aten.to_dtype
pull/2326/head
Jiawei Wu 2023-07-20 09:51:58 +08:00 committed by GitHub
parent c9add6b7d8
commit 3f843c8fd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 72 additions and 24 deletions

View File

@ -767,6 +767,7 @@ STABLEHLO_PASS_SET = {
"ToDtypeLayoutNoneModule_basic", "ToDtypeLayoutNoneModule_basic",
"ToDtypeLayoutStridedModule_basic", "ToDtypeLayoutStridedModule_basic",
"TypeAsSameModule_basic", "TypeAsSameModule_basic",
"TypeAsDifferentModule_basic",
"TypeConversionF32ToF64Module_basic", "TypeConversionF32ToF64Module_basic",
"TypeConversionF64ToF32Module_basic", "TypeConversionF64ToF32Module_basic",
"TypeConversionI1ToF32Module_basic", "TypeConversionI1ToF32Module_basic",

View File

@ -8001,7 +8001,6 @@ def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
printDefaultTorchOp(printer, *this, 2, 1); printDefaultTorchOp(printer, *this, 2, 1);
} }
}]; }];
let hasFolder = 1;
} }
def Torch_AtenViewOp : Torch_Op<"aten.view", [ def Torch_AtenViewOp : Torch_Op<"aten.view", [

View File

@ -712,20 +712,6 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
return nullptr; return nullptr;
} }
//===----------------------------------------------------------------------===//
// AtenTypeAsOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
Type inType = getSelf().getType();
Type newType = getOther().getType();
if (inType == newType)
return getSelf();
return nullptr;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AtenToDtypeOp // AtenToDtypeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -4590,6 +4590,29 @@ public:
}; };
} // namespace } // namespace
namespace {
// Unconditionally decompose `torch.type_as` into `prim.dtype` +
// `torch.to.dtype`.
class DecomposeAtenTypeAsOp : public OpRewritePattern<AtenTypeAsOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTypeAsOp op,
PatternRewriter &rewriter) const override {
auto input = op.getSelf();
auto other = op.getOther();
Location loc = op.getLoc();
Value targetDtype = rewriter.create<Torch::PrimDtypeOp>(loc, other);
Value nonBlocking = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value copy = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value memoryFormat = rewriter.create<Torch::ConstantNoneOp>(loc);
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
op, op.getType(), input, targetDtype, nonBlocking, copy, memoryFormat);
return success();
}
};
} // namespace
namespace { namespace {
class DecomposeComplexOpsPass class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> { : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -4759,6 +4782,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
GreedyRewriteConfig config; GreedyRewriteConfig config;
config.useTopDownTraversal = true; config.useTopDownTraversal = true;

View File

@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenTopkOp>(); target.addIllegalOp<AtenTopkOp>();
target.addIllegalOp<AtenScalarTensorOp>(); target.addIllegalOp<AtenScalarTensorOp>();
target.addIllegalOp<AtenScatterValueOp>(); target.addIllegalOp<AtenScatterValueOp>();
target.addIllegalOp<AtenTypeAsOp>();
for (auto &opName : backendLegalOpsSet) { for (auto &opName : backendLegalOpsSet) {
target.addLegalOp( target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context)); OperationName(kTorchOpPrefix + opName.first().str(), context));

View File

@ -521,7 +521,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True) emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True)
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)") emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)") emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)", has_folder=True) emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)") emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")

View File

@ -235,6 +235,27 @@ class TypeAsSameModule(torch.nn.Module):
def TypeAsSameModule_basic(module, tu: TestUtils): def TypeAsSameModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 5), tu.rand(3, 5)) module.forward(tu.rand(3, 5), tu.rand(3, 5))
class TypeAsDifferentModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int, True),
([-1, -1], torch.int64, True),
])
def forward(self, x, y):
return torch.ops.aten.type_as(x, y)
@register_test_case(module_factory=lambda: TypeAsDifferentModule())
def TypeAsDifferentModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(3, 5, low=0, high=10, dtype=torch.int),
tu.randint(3, 5, low=0, high=10, dtype=torch.int64)
)
# ============================================================================== # ==============================================================================

View File

@ -1412,14 +1412,6 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to
return %0 : !torch.tensor<[],f32> return %0 : !torch.tensor<[],f32>
} }
// CHECK-LABEL: func.func @torch.aten.type_as$same(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32>
func.func @torch.aten.type_as$same(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
%0 = torch.aten.type_as %arg0, %arg0 : !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
return %0 : !torch.tensor<[?,?],f32>
}
// CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype( // CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> { // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32> // CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32>

View File

@ -79,3 +79,27 @@ func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor
%0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32> %0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32>
} }
// -----
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int
// CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
// CHECK: return %[[VAR]] : !torch.tensor
func.func @torch.aten.type_as$basic(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor
return %0 : !torch.tensor
}
// -----
// CHECK-LABEL: func.func @torch.aten.type_as$fold(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor<[?],f16>, %[[ARG_1:.*]]: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
// CHECK: return %[[ARG_0]] : !torch.tensor<[?],f16>
func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
return %0 : !torch.tensor<[?], f16>
}