mirror of https://github.com/llvm/torch-mlir
torch: allow torch dialect ops after running drop-shape pass (#979)
In the `pyhpc_turbulent_kinetic_energy` TorchBench benchmark, the shape calculation occurs inside loops, but because `DropShapeCalculationsPass` does not explicitly mark the Torch dialect as legal, the pass execution fails. This patch adds Torch to the list of legal dialects, and adds a test to validate the translation.pull/977/merge snapshot-20220626.515
parent
1be604bfd3
commit
163fa57cde
|
@ -53,6 +53,7 @@ class DropShapeCalculationsPass
|
|||
RewritePatternSet patterns(context);
|
||||
patterns.insert<DropShapeCalculateOp>(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<Torch::TorchDialect>();
|
||||
target.addIllegalOp<ShapeCalculateOp>();
|
||||
target.addLegalOp<func::FuncOp>();
|
||||
|
||||
|
|
|
@ -19,3 +19,39 @@ func.func @basic(%arg0: !torch.vtensor<[2,?],unk>) -> !torch.vtensor {
|
|||
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[2,?],unk> to !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @shape_calc_in_loop(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> {
|
||||
func.func @shape_calc_in_loop(%arg: !torch.vtensor<[2,?],unk>) -> !torch.vtensor<[2,?],unk> {
|
||||
%one = torch.constant.int 1
|
||||
// CHECK: %[[ONE:.*]] = torch.constant.int 1
|
||||
|
||||
%two = torch.constant.int 2
|
||||
// CHECK: %[[TWO:.*]] = torch.constant.int 2
|
||||
|
||||
%true = torch.constant.bool true
|
||||
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
||||
|
||||
torch.prim.Loop %one, %true, init() {
|
||||
// CHECK: torch.prim.Loop %[[ONE]], %[[TRUE]], init() {
|
||||
|
||||
^bb0(%in: !torch.int):
|
||||
%shape_calc = torch.shape.calculate {
|
||||
%tanh = torch.aten.tanh %arg : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk>
|
||||
torch.shape.calculate.yield %tanh : !torch.vtensor<[2,?],unk>
|
||||
} shapes {
|
||||
%size = torch.aten.size.int %arg, %one : !torch.vtensor<[2,?],unk>, !torch.int -> !torch.int
|
||||
%list = torch.prim.ListConstruct %two, %size : (!torch.int, !torch.int) -> !torch.list<int>
|
||||
torch.shape.calculate.yield.shapes %list : !torch.list<int>
|
||||
} : !torch.vtensor<[2,?],unk>
|
||||
// CHECK: torch.aten.tanh %[[ARG]] : !torch.vtensor<[2,?],unk> -> !torch.vtensor<[2,?],unk>
|
||||
|
||||
torch.prim.Loop.condition %true, iter()
|
||||
// CHECK: torch.prim.Loop.condition %[[TRUE]], iter()
|
||||
} : (!torch.int, !torch.bool) -> ()
|
||||
|
||||
return %arg : !torch.vtensor<[2,?],unk>
|
||||
// CHECK: return %[[ARG]] : !torch.vtensor<[2,?],unk>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue