From 163fa57cde44dd6a8f14e422d341d8deb4a40c62 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Sat, 25 Jun 2022 07:27:47 -0700 Subject: [PATCH] torch: allow torch dialect ops after running drop-shape pass (#979) In the `pyhpc_turbulent_kinetic_energy` TorchBench benchmark, the shape calculation occurs inside loops, but because `DropShapeCalculationsPass` does not explicitly mark the Torch dialect as legal, the pass execution fails. This patch adds Torch to the list of legal dialects, and adds a test to validate the translation. --- .../Transforms/DropShapeCalculations.cpp | 1 + .../Torch/drop-shape-calculations.mlir | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/DropShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/DropShapeCalculations.cpp index aeb4b3dfb..d21a7100d 100644 --- a/lib/Dialect/Torch/Transforms/DropShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/DropShapeCalculations.cpp @@ -53,6 +53,7 @@ class DropShapeCalculationsPass RewritePatternSet patterns(context); patterns.insert(context); ConversionTarget target(*context); + target.addLegalDialect(); target.addIllegalOp(); target.addLegalOp(); diff --git a/test/Dialect/Torch/drop-shape-calculations.mlir b/test/Dialect/Torch/drop-shape-calculations.mlir index 9aae69aea..3a0ab8435 100644 --- a/test/Dialect/Torch/drop-shape-calculations.mlir +++ b/test/Dialect/Torch/drop-shape-calculations.mlir @@ -19,3 +19,39 @@ func.func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor { %1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[2,?],unk> to !torch.vtensor return %1 : !torch.vtensor } + +// ----- + +// CHECK-LABEL: func.func @shape_calc_in_loop( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> { +func.func @shape_calc_in_loop(%arg: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> { + %one = torch.constant.int 1 + // CHECK: %[[ONE:.*]] = torch.constant.int 1 + + %two = torch.constant.int 2 + // CHECK: %[[TWO:.*]] = torch.constant.int 2 + + %true = torch.constant.bool true + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + + torch.prim.Loop %one, %true, init() { + // CHECK: torch.prim.Loop %[[ONE]], %[[TRUE]], init() { + + ^bb0(%in: !torch.int): + %shape_calc = torch.shape.calculate { + %tanh = torch.aten.tanh %arg : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk> + torch.shape.calculate.yield %tanh : !torch.vtensor<[2,?],unk> + } shapes { + %size = torch.aten.size.int %arg, %one : !torch.vtensor<[2,?],unk>, !torch.int -> !torch.int + %list = torch.prim.ListConstruct %two, %size : (!torch.int, !torch.int) -> !torch.list + torch.shape.calculate.yield.shapes %list : !torch.list + } : !torch.vtensor<[2,?],unk> + // CHECK: torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk> + + torch.prim.Loop.condition %true, iter() + // CHECK: torch.prim.Loop.condition %[[TRUE]], iter() + } : (!torch.int, !torch.bool) -> () + + return %arg : !torch.vtensor<[2,?],unk> + // CHECK: return %[[ARG]] : !torch.vtensor<[2,?],unk> +}