mirror of https://github.com/llvm/torch-mlir
[torch-dialect] fix aten.type_as op's folder (#2283)
[torch-dialect] fix torch.type_as op's folder by decomposing it to prim.dtype + aten.to_dtypepull/2326/head
parent
c9add6b7d8
commit
3f843c8fd9
|
@ -767,6 +767,7 @@ STABLEHLO_PASS_SET = {
|
||||||
"ToDtypeLayoutNoneModule_basic",
|
"ToDtypeLayoutNoneModule_basic",
|
||||||
"ToDtypeLayoutStridedModule_basic",
|
"ToDtypeLayoutStridedModule_basic",
|
||||||
"TypeAsSameModule_basic",
|
"TypeAsSameModule_basic",
|
||||||
|
"TypeAsDifferentModule_basic",
|
||||||
"TypeConversionF32ToF64Module_basic",
|
"TypeConversionF32ToF64Module_basic",
|
||||||
"TypeConversionF64ToF32Module_basic",
|
"TypeConversionF64ToF32Module_basic",
|
||||||
"TypeConversionI1ToF32Module_basic",
|
"TypeConversionI1ToF32Module_basic",
|
||||||
|
|
|
@ -8001,7 +8001,6 @@ def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
let hasFolder = 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenViewOp : Torch_Op<"aten.view", [
|
def Torch_AtenViewOp : Torch_Op<"aten.view", [
|
||||||
|
|
|
@ -712,20 +712,6 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// AtenTypeAsOp
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
|
|
||||||
Type inType = getSelf().getType();
|
|
||||||
Type newType = getOther().getType();
|
|
||||||
|
|
||||||
if (inType == newType)
|
|
||||||
return getSelf();
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// AtenToDtypeOp
|
// AtenToDtypeOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -4590,6 +4590,29 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Unconditionally decompose `torch.type_as` into `prim.dtype` +
|
||||||
|
// `torch.to.dtype`.
|
||||||
|
class DecomposeAtenTypeAsOp : public OpRewritePattern<AtenTypeAsOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenTypeAsOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto input = op.getSelf();
|
||||||
|
auto other = op.getOther();
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
|
Value targetDtype = rewriter.create<Torch::PrimDtypeOp>(loc, other);
|
||||||
|
Value nonBlocking = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||||
|
Value copy = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||||
|
Value memoryFormat = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||||
|
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
|
||||||
|
op, op.getType(), input, targetDtype, nonBlocking, copy, memoryFormat);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -4759,6 +4782,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
|
||||||
|
|
||||||
GreedyRewriteConfig config;
|
GreedyRewriteConfig config;
|
||||||
config.useTopDownTraversal = true;
|
config.useTopDownTraversal = true;
|
||||||
|
|
|
@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenTopkOp>();
|
target.addIllegalOp<AtenTopkOp>();
|
||||||
target.addIllegalOp<AtenScalarTensorOp>();
|
target.addIllegalOp<AtenScalarTensorOp>();
|
||||||
target.addIllegalOp<AtenScatterValueOp>();
|
target.addIllegalOp<AtenScatterValueOp>();
|
||||||
|
target.addIllegalOp<AtenTypeAsOp>();
|
||||||
for (auto &opName : backendLegalOpsSet) {
|
for (auto &opName : backendLegalOpsSet) {
|
||||||
target.addLegalOp(
|
target.addLegalOp(
|
||||||
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
||||||
|
|
|
@ -521,7 +521,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True)
|
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True)
|
||||||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||||
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
|
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
|
||||||
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)", has_folder=True)
|
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||||
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
|
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
|
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||||
|
|
|
@ -235,6 +235,27 @@ class TypeAsSameModule(torch.nn.Module):
|
||||||
def TypeAsSameModule_basic(module, tu: TestUtils):
|
def TypeAsSameModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
||||||
|
|
||||||
|
class TypeAsDifferentModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.int, True),
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.type_as(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TypeAsDifferentModule())
|
||||||
|
def TypeAsDifferentModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(
|
||||||
|
tu.randint(3, 5, low=0, high=10, dtype=torch.int),
|
||||||
|
tu.randint(3, 5, low=0, high=10, dtype=torch.int64)
|
||||||
|
)
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
|
@ -1412,14 +1412,6 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to
|
||||||
return %0 : !torch.tensor<[],f32>
|
return %0 : !torch.tensor<[],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.type_as$same(
|
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
|
||||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32>
|
|
||||||
func.func @torch.aten.type_as$same(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
|
||||||
%0 = torch.aten.type_as %arg0, %arg0 : !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
|
|
||||||
return %0 : !torch.tensor<[?,?],f32>
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype(
|
// CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
|
||||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32>
|
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32>
|
||||||
|
|
|
@ -79,3 +79,27 @@ func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor
|
||||||
%0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
|
%0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
return %0 : !torch.vtensor<[?,?,?,?],f32>
|
return %0 : !torch.vtensor<[?,?,?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
|
||||||
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int
|
||||||
|
// CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
|
||||||
|
// CHECK: return %[[VAR]] : !torch.tensor
|
||||||
|
func.func @torch.aten.type_as$basic(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
|
||||||
|
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor
|
||||||
|
return %0 : !torch.tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.type_as$fold(
|
||||||
|
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor<[?],f16>, %[[ARG_1:.*]]: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
|
||||||
|
// CHECK: return %[[ARG_0]] : !torch.tensor<[?],f16>
|
||||||
|
func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
|
||||||
|
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
|
||||||
|
return %0 : !torch.tensor<[?], f16>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue