[Torch Dialect] Add to.dtype_layout canonicalize patterns (#2062)

* add to.dtype_layout canonicalize patterns

* update comment

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
Zhekun Zhang 2023-05-02 20:06:02 -07:00 committed by GitHub
parent c596d11b98
commit 0cf9ee340b
4 changed files with 76 additions and 1 deletion

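These patterns target the most common way aten.to.dtype_layout reaches torch-mlir: Tensor.to called with keyword arguments while layout and pin_memory keep their defaults. A minimal Python sketch of how such a graph can arise (an illustration, not part of the change; whether TorchScript picks the dtype_layout overload here depends on the PyTorch build, so inspect the printed graph to confirm):

import torch

def to_half(x):
    # Keyword-style Tensor.to; layout, device, and pin_memory keep their
    # None defaults.
    return x.to(dtype=torch.float16)

# Expectation (assumption): the graph contains an aten::to.dtype_layout node
# whose layout, device, and pin_memory inputs are constant None -- exactly the
# form the new patterns rewrite to aten::to.dtype.
print(torch.jit.script(to_half).graph)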

@@ -7479,6 +7479,7 @@ def Torch_AtenToDtypeLayoutOp : Torch_Op<"aten.to.dtype_layout", [
    }
  }];
  let hasFolder = 1;
  let hasCanonicalizer = 1;
}

def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [


@@ -797,6 +797,48 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
  return getOperand(0);
}

void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
    RewritePatternSet &patterns, MLIRContext *context) {
  // Rewrite `to.dtype_layout` into `to.device`/`to.dtype` when the layout is
  // `none` (or strided) and pin_memory is `none` (or `False`).
  patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) {
    // The pin_memory arg should be either constant `False` or `none`.
    if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
      bool pinMemory;
      if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
        return failure();
      else if (pinMemory)
        return failure();
    }

    // The layout arg should be either `none` or `0`, i.e. strided.
    if (!op.getLayout().getType().isa<Torch::NoneType>()) {
      int64_t tensorLayout;
      if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
        return failure();
      else if (tensorLayout != torch_upstream::Layout::Strided)
        return failure();
    }

    if (op.getDevice().getType().isa<Torch::NoneType>()) {
      // The device arg is `none`. Rewrite to to.dtype.
      AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>(
          op.getLoc(), op.getType(), op.getSelf(), op.getDtype(),
          op.getNonBlocking(), op.getCopy(), op.getMemoryFormat());
      rewriter.replaceOp(op, toDtype->getResults());
    } else {
      // The device arg is not `none`. Rewrite to to.device.
      AtenToDeviceOp toDevice = rewriter.create<AtenToDeviceOp>(
          op.getLoc(), op.getType(), op.getSelf(), op.getDevice(),
          op.getDtype(), op.getNonBlocking(), op.getCopy(),
          op.getMemoryFormat());
      rewriter.replaceOp(op, toDevice->getResults());
    }
    return success();
  });
}

//===----------------------------------------------------------------------===//
// AtenViewOp
//===----------------------------------------------------------------------===//

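The rewrite encodes an eager-mode equivalence: when layout is none or strided and pin_memory is none or False, to.dtype_layout computes the same tensor as to.dtype (no device given) or to.device (device given). A quick Python sanity check of that equivalence, calling the aten overloads directly (a sketch, not part of the change; non_blocking, copy, and memory_format are left at their defaults, mirroring the operands the pattern forwards unchanged):

import torch

x = torch.randn(2, 3)

# Device is none and layout is strided: expected to match aten::to.dtype.
a = torch.ops.aten.to.dtype_layout(x, dtype=torch.float16, layout=torch.strided)
b = torch.ops.aten.to.dtype(x, torch.float16)
assert a.dtype == b.dtype and torch.equal(a, b)

# A device is given: expected to match aten::to.device.
dev = torch.device("cpu")
c = torch.ops.aten.to.dtype_layout(x, dtype=torch.float16, device=dev)
d = torch.ops.aten.to.device(x, dev, torch.float16)
assert c.dtype == d.dtype and torch.equal(c, d)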

@@ -502,7 +502,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
    emit("aten::amax : (Tensor, int[], bool) -> (Tensor)")
    emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
    emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True)
    emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer=True)
    emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
    emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
    emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")


@@ -1451,6 +1451,38 @@ func.func @torch.aten.to.dtype_layout$same_dtype(%arg0: !torch.tensor<[?,?],f32>
  return %0 : !torch.tensor<[?,?],f32>
}

// CHECK-LABEL: func.func @torch.aten.to.dtype_layout$to_device(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
// CHECK-NEXT: %[[INT6:.*]] = torch.constant.int 6
// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
// CHECK-NEXT: %[[CPU:.*]] = torch.constant.device "cpu"
// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.to.device %[[ARG]], %[[CPU]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[?,?],f32>, !torch.Device, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
// CHECK-NEXT: return %[[RESULT]] : !torch.tensor<[?,?],f32>
func.func @torch.aten.to.dtype_layout$to_device(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
  %none = torch.constant.none
  %device = torch.constant.device "cpu"
  %false = torch.constant.bool false
  %int6 = torch.constant.int 6
  %0 = torch.aten.to.dtype_layout %arg0, %int6, %none, %device, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.Device, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
  return %0 : !torch.tensor<[?,?],f32>
}

// CHECK-LABEL: func.func @torch.aten.to.dtype_layout$to_dtype(
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f16> {
// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
// CHECK-NEXT: %[[INT5:.*]] = torch.constant.int 5
// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.to.dtype %[[ARG]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f16>
// CHECK-NEXT: return %[[RESULT]] : !torch.tensor<[?,?],f16>
func.func @torch.aten.to.dtype_layout$to_dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f16> {
  %none = torch.constant.none
  %false = torch.constant.bool false
  %int5 = torch.constant.int 5
  %0 = torch.aten.to.dtype_layout %arg0, %int5, %none, %none, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f16>
  return %0 : !torch.tensor<[?,?],f16>
}

// CHECK-LABEL: func.func @torch.aten.ge.float$same_operand(
// CHECK-SAME: %{{.*}}: !torch.float) -> !torch.bool {
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
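To watch the new patterns fire end to end, one option is to lower a small module to the Torch backend contract and inspect the IR. A sketch under the assumption that this snapshot's Python API exposes torch_mlir.compile with output_type="torch" and that its pipeline runs canonicalization:

import torch
import torch_mlir

class ToHalf(torch.nn.Module):
    def forward(self, x):
        return x.to(dtype=torch.float16)

# Expectation (assumption): the printed Torch IR contains torch.aten.to.dtype
# rather than torch.aten.to.dtype_layout once canonicalization has run.
module = torch_mlir.compile(ToHalf(), torch.randn(2, 3), output_type="torch")
print(module)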