mirror of https://github.com/llvm/torch-mlir
[Torch] Add decompose of AtenToPrimDeviceOp (#3131)
As device information isn't relevant to torch-mlirpull/3138/head
parent
8951a8cc23
commit
5eb0cf9104
|
@ -10529,8 +10529,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.to.prim_Device\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<Device>, %arg2: !torch.optional<int>, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" %1 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %0#1 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %3 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %3 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.transpose.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -5450,6 +5450,33 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.to.prim_Device` op into `aten.to.dtype` op.
|
||||
class DecomposeAtenToPrimDeviceOp
|
||||
: public OpRewritePattern<AtenToPrimDeviceOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenToPrimDeviceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
// Device information isn't relevant to torch-mlir, so we can drop that info
|
||||
// here.
|
||||
auto loc = op.getLoc();
|
||||
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
||||
Value dtype = op.getDtype();
|
||||
if (dtype.getType().template isa<Torch::NoneType>()) {
|
||||
dtype = rewriter.create<Torch::PrimDtypeOp>(loc, op.getSelf());
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
|
||||
dtype, op.getNonBlocking(),
|
||||
op.getCopy(), constNone);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose `aten.to.device` op into `aten.to.dtype` op.
|
||||
class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
|
||||
|
@ -7559,6 +7586,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenPadOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
|
||||
|
|
|
@ -475,6 +475,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenPreluOp>();
|
||||
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
||||
target.addIllegalOp<AtenToDeviceOp>();
|
||||
target.addIllegalOp<AtenToPrimDeviceOp>();
|
||||
target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
|
||||
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
|
||||
target.addIllegalOp<AtenClampMinOp>();
|
||||
|
|
|
@ -2755,7 +2755,9 @@ def aten〇t〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
@check_dtype_function(_check_tensors_with_the_same_dtype(1, tensor_device="meta", device=torch.device("meta")))
|
||||
def aten〇to〇prim_Device〡dtype(self_rank_dtype: Tuple[int, int], device: Optional[device], dtype: Optional[int] = None, non_blocking: bool = False, copy: bool = False) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
if dtype is None:
|
||||
return self_dtype
|
||||
return dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim0=0, dim1=1))
|
||||
def aten〇transpose〇int〡dtype(self_rank_dtype: Tuple[int, int], dim0: int, dim1: int) -> int:
|
||||
|
|
Loading…
Reference in New Issue