mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Add to.dtype_layout canonicalize patterns (#2062)
* add to.dtype_layout canonicalize patterns * update comment --------- Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>pull/2087/head snapshot-20230503.827
parent
c596d11b98
commit
0cf9ee340b
|
@ -7479,6 +7479,7 @@ def Torch_AtenToDtypeLayoutOp : Torch_Op<"aten.to.dtype_layout", [
|
|||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [
|
||||
|
|
|
@ -797,6 +797,48 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
|||
return getOperand(0);
|
||||
}
|
||||
|
||||
void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
// `to.dtype_layout` -> `to.device/to.dtype` if layout is none and pin memory
|
||||
// is false
|
||||
patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) {
|
||||
// The pin_memory arg should be either constant `False` or `none`.
|
||||
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
|
||||
bool pinMemory;
|
||||
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
||||
return failure();
|
||||
else if (pinMemory)
|
||||
return failure();
|
||||
}
|
||||
|
||||
// The layout arg should be either `none` or `0` i.e. strided.
|
||||
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
||||
int64_t tensorLayout;
|
||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||
return failure();
|
||||
else if (tensorLayout != torch_upstream::Layout::Strided)
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (op.getDevice().getType().isa<Torch::NoneType>()) {
|
||||
// The device arg is `none`. Rewrite to to.dtype.
|
||||
AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>(
|
||||
op.getLoc(), op.getType(), op.getSelf(), op.getDtype(),
|
||||
op.getNonBlocking(), op.getCopy(), op.getMemoryFormat());
|
||||
rewriter.replaceOp(op, toDtype->getResults());
|
||||
} else {
|
||||
// The device arg is not `none`. Rewrite to to.device.
|
||||
AtenToDeviceOp toDevice = rewriter.create<AtenToDeviceOp>(
|
||||
op.getLoc(), op.getType(), op.getSelf(), op.getDevice(),
|
||||
op.getDtype(), op.getNonBlocking(), op.getCopy(),
|
||||
op.getMemoryFormat());
|
||||
rewriter.replaceOp(op, toDevice->getResults());
|
||||
}
|
||||
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -502,7 +502,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::amax : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True)
|
||||
emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer = True)
|
||||
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)")
|
||||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
|
||||
|
|
|
@ -1451,6 +1451,38 @@ func.func @torch.aten.to.dtype_layout$same_dtype(%arg0: !torch.tensor<[?,?],f32>
|
|||
return %0 : !torch.tensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.to.dtype_layout$to_device(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
||||
// CHECK-NEXT: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK-NEXT: %[[CPU:.*]] = torch.constant.device "cpu"
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.to.device %[[ARG]], %[[CPU]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[?,?],f32>, !torch.Device, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
|
||||
// CHECK-NEXT: return %[[RESULT]] : !torch.tensor<[?,?],f32>
|
||||
func.func @torch.aten.to.dtype_layout$to_device(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f32> {
|
||||
%none = torch.constant.none
|
||||
%device = torch.constant.device "cpu"
|
||||
%false = torch.constant.bool false
|
||||
%int6 = torch.constant.int 6
|
||||
%0 = torch.aten.to.dtype_layout %arg0, %int6, %none, %device, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.Device, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f32>
|
||||
return %0 : !torch.tensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.to.dtype_layout$to_dtype(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f16> {
|
||||
// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK-NEXT: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK-NEXT: %[[INT5:.*]] = torch.constant.int 5
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = torch.aten.to.dtype %[[ARG]], %[[INT5]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor<[?,?],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f16>
|
||||
// CHECK-NEXT: return %[[RESULT]] : !torch.tensor<[?,?],f16>
|
||||
func.func @torch.aten.to.dtype_layout$to_dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor<[?,?],f16> {
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%int5 = torch.constant.int 5
|
||||
%0 = torch.aten.to.dtype_layout %arg0, %int5, %none, %none, %none, %false, %false, %none : !torch.tensor<[?,?],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.bool, !torch.none -> !torch.tensor<[?,?],f16>
|
||||
return %0 : !torch.tensor<[?,?],f16>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.ge.float$same_operand(
|
||||
// CHECK-SAME: %{{.*}}: !torch.float) -> !torch.bool {
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
|
|
Loading…
Reference in New Issue