[Torch] Add decompose of AtenToPrimDeviceOp (#3131)

As device information isn't relevant to torch-mlir, decompose `aten.to.prim_Device` into `aten.to.dtype` and drop the device argument.
Xinyu Yang 2024-04-10 22:26:48 +08:00 committed by GitHub
parent 8951a8cc23
commit 5eb0cf9104
4 changed files with 41 additions and 2 deletions
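
Background: `aten.to.prim_Device` is the `Tensor.to` overload that takes an
optional Device and an optional dtype; the TorchScript frontend falls back to
it for calls like `x.to(device)` where no concrete dtype is supplied. A
minimal sketch of a function that should script to this overload (illustrative
only, not part of this commit; which overload the frontend actually selects
depends on the argument types):

    import torch
    from typing import Optional

    def f(x: torch.Tensor, dev: Optional[torch.device]) -> torch.Tensor:
        # With an optional device target and no dtype, this call resolves to
        # aten::to.prim_Device(Tensor, Device?, int?, bool, bool).
        return x.to(dev)

    scripted = torch.jit.script(f)
    print(scripted.graph)  # inspect the aten::to call in the printed graph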

@@ -10529,8 +10529,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    return %0#1 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.to.prim_Device\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<Device>, %arg2: !torch.optional<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
-"    return %0#1 : !torch.int\n"
+"    %1 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.int) {\n"
+"      torch.prim.If.yield %0#1 : !torch.int\n"
+"    } else {\n"
+"      %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
+"      torch.prim.If.yield %3 : !torch.int\n"
+"    }\n"
+"    return %2 : !torch.int\n"
 "  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.transpose.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

@@ -5450,6 +5450,33 @@ public:
 };
 } // namespace
 
+namespace {
+// Decompose `aten.to.prim_Device` op into `aten.to.dtype` op.
+class DecomposeAtenToPrimDeviceOp
+    : public OpRewritePattern<AtenToPrimDeviceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenToPrimDeviceOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // Device information isn't relevant to torch-mlir, so we can drop that
+    // info here.
+    auto loc = op.getLoc();
+    Value constNone = rewriter.create<ConstantNoneOp>(loc);
+    Value dtype = op.getDtype();
+    if (dtype.getType().template isa<Torch::NoneType>()) {
+      dtype = rewriter.create<Torch::PrimDtypeOp>(loc, op.getSelf());
+    }
+
+    rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
+                                               dtype, op.getNonBlocking(),
+                                               op.getCopy(), constNone);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.to.device` op into `aten.to.dtype` op.
 class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
@@ -7559,6 +7586,7 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
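
A decomposition like this is normally exercised end to end; below is a sketch
of such a test using torch-mlir's e2e harness (the module name is hypothetical
and this test is not part of the commit; it assumes `x.to(torch.device("cpu"))`
reaches aten.to.prim_Device with dtype = None, exercising the PrimDtypeOp
branch of the new pattern):

    import torch
    from torch_mlir_e2e_test.framework import TestUtils
    from torch_mlir_e2e_test.registry import register_test_case
    from torch_mlir_e2e_test.annotations import annotate_args, export

    class ToPrimDeviceModule(torch.nn.Module):
        @export
        @annotate_args([
            None,
            ([-1, -1], torch.float32, True),
        ])
        def forward(self, x):
            # No dtype given, so the decomposed aten.to.dtype should pick up
            # the input dtype via prim.dtype.
            return x.to(torch.device("cpu"))

    @register_test_case(module_factory=lambda: ToPrimDeviceModule())
    def ToPrimDeviceModule_basic(module, tu: TestUtils):
        module.forward(tu.rand(3, 4))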

@@ -475,6 +475,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenPreluOp>();
   target.addIllegalOp<AtenToDtypeLayoutOp>();
   target.addIllegalOp<AtenToDeviceOp>();
+  target.addIllegalOp<AtenToPrimDeviceOp>();
   target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
   target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
   target.addIllegalOp<AtenClampMinOp>();

@@ -2755,7 +2755,9 @@ def aten〇t〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
 @check_dtype_function(_check_tensors_with_the_same_dtype(1, tensor_device="meta", device=torch.device("meta")))
 def aten〇to〇prim_Device〡dtype(self_rank_dtype: Tuple[int, int], device: Optional[device], dtype: Optional[int] = None, non_blocking: bool = False, copy: bool = False) -> int:
     self_rank, self_dtype = self_rank_dtype
-    return self_dtype
+    if dtype is None:
+        return self_dtype
+    return dtype
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim0=0, dim1=1))
 def aten〇transpose〇int〡dtype(self_rank_dtype: Tuple[int, int], dim0: int, dim1: int) -> int:
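
As a quick sanity check, the same dtype rule can be observed in eager PyTorch
(not part of the commit):

    import torch

    x = torch.rand(2, 3)  # float32 input
    # dtype is None: the result keeps the input dtype.
    assert x.to(torch.device("cpu")).dtype == torch.float32
    # Explicit dtype: the result takes the requested dtype.
    assert x.to(torch.device("cpu"), dtype=torch.int64).dtype == torch.int64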