mirror of https://github.com/llvm/torch-mlir
[torch-dialect] fix aten.type_as op's folder (#2283)
[torch-dialect] fix torch.type_as op's folder by decomposing it to prim.dtype + aten.to_dtypepull/2326/head
parent
c9add6b7d8
commit
3f843c8fd9
|
@ -767,6 +767,7 @@ STABLEHLO_PASS_SET = {
|
|||
"ToDtypeLayoutNoneModule_basic",
|
||||
"ToDtypeLayoutStridedModule_basic",
|
||||
"TypeAsSameModule_basic",
|
||||
"TypeAsDifferentModule_basic",
|
||||
"TypeConversionF32ToF64Module_basic",
|
||||
"TypeConversionF64ToF32Module_basic",
|
||||
"TypeConversionI1ToF32Module_basic",
|
||||
|
|
|
@ -8001,7 +8001,6 @@ def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
|
|||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenViewOp : Torch_Op<"aten.view", [
|
||||
|
|
|
@ -712,20 +712,6 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenTypeAsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenTypeAsOp::fold(FoldAdaptor adaptor) {
|
||||
Type inType = getSelf().getType();
|
||||
Type newType = getOther().getType();
|
||||
|
||||
if (inType == newType)
|
||||
return getSelf();
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenToDtypeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -4590,6 +4590,29 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Unconditionally decompose `torch.type_as` into `prim.dtype` +
|
||||
// `torch.to.dtype`.
|
||||
class DecomposeAtenTypeAsOp : public OpRewritePattern<AtenTypeAsOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenTypeAsOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto input = op.getSelf();
|
||||
auto other = op.getOther();
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Value targetDtype = rewriter.create<Torch::PrimDtypeOp>(loc, other);
|
||||
Value nonBlocking = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
Value copy = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
Value memoryFormat = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
|
||||
op, op.getType(), input, targetDtype, nonBlocking, copy, memoryFormat);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeComplexOpsPass
|
||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||
|
@ -4759,6 +4782,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.useTopDownTraversal = true;
|
||||
|
|
|
@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenTopkOp>();
|
||||
target.addIllegalOp<AtenScalarTensorOp>();
|
||||
target.addIllegalOp<AtenScatterValueOp>();
|
||||
target.addIllegalOp<AtenTypeAsOp>();
|
||||
for (auto &opName : backendLegalOpsSet) {
|
||||
target.addLegalOp(
|
||||
OperationName(kTorchOpPrefix + opName.first().str(), context));
|
||||
|
|
|
@ -521,7 +521,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
|
||||
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)")
|
||||
|
|
|
@ -235,6 +235,27 @@ class TypeAsSameModule(torch.nn.Module):
|
|||
def TypeAsSameModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
||||
|
||||
class TypeAsDifferentModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.type_as(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TypeAsDifferentModule())
|
||||
def TypeAsDifferentModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(3, 5, low=0, high=10, dtype=torch.int),
|
||||
tu.randint(3, 5, low=0, high=10, dtype=torch.int64)
|
||||
)
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
|
@ -1412,14 +1412,6 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to
|
|||
return %0 : !torch.tensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.type_as$same(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<[?,?],f32>
|
||||
func.func @torch.aten.type_as$same(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
||||
%0 = torch.aten.type_as %arg0, %arg0 : !torch.tensor<[?,?],f32>, !torch.tensor<[?,?],f32> -> !torch.tensor<[?,?],f32>
|
||||
return %0 : !torch.tensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32>
|
||||
|
|
|
@ -79,3 +79,27 @@ func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor
|
|||
%0 = torch.aten.adaptive_avg_pool2d %arg0, %output_size : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?,?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.type_as$basic(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor {
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int
|
||||
// CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor
|
||||
// CHECK: return %[[VAR]] : !torch.tensor
|
||||
func.func @torch.aten.type_as$basic(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor {
|
||||
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.type_as$fold(
|
||||
// CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor<[?],f16>, %[[ARG_1:.*]]: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
|
||||
// CHECK: return %[[ARG_0]] : !torch.tensor<[?],f16>
|
||||
func.func @torch.aten.type_as$fold(%arg0: !torch.tensor<[?], f16>, %arg1: !torch.tensor<[?,?],f16>) -> !torch.tensor<[?],f16> {
|
||||
%0 = torch.aten.type_as %arg0, %arg1 : !torch.tensor<[?], f16>, !torch.tensor<[?,?],f16> -> !torch.tensor<[?], f16>
|
||||
return %0 : !torch.tensor<[?], f16>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue