From 224afb186e43c8d3ca73a1ba8fa7b0e00f53d426 Mon Sep 17 00:00:00 2001
From: Sean Silva
Date: Wed, 16 Jun 2021 11:05:08 -0700
Subject: [PATCH] Add folders for torch.aten.gt.int / torch.aten.ne.int

This fixes a "regression" on ResNet where we weren't folding away all
the control flow. For now, our policy is to "optimize hard enough" to
make that control flow go away, because we don't yet have a way to
lower the code guarded by that control flow (RaiseException, string
operations, etc.) to the backend.

It remains to be seen how much optimization we decide to do at this
level in the fullness of time -- the torch op set is not particularly
well-designed (at least not idiomatically for MLIR) for general
optimization. Ideally, with really good backend support for various
features, all the heavy optimization will happen at that layer on `std`
ops and `scf` control flow. But I have a suspicion we might end up
needing more optimization earlier in the pipeline.
---
 .../torch_mlir_utils/codegen/torch_ods_gen.py |  4 +--
 .../Dialect/Torch/IR/GeneratedAtenOps.td      |  2 ++
 lib/Dialect/Torch/IR/TorchOps.cpp             | 30 +++++++++++++++++++
 lib/Dialect/Torch/Transforms/Passes.cpp       | 20 ++++++-------
 test/Dialect/Torch/canonicalize.mlir          | 19 ++++++++++++
 5 files changed, 63 insertions(+), 12 deletions(-)

diff --git a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py
index 25c80b6c7..2624df8ec 100644
--- a/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py
+++ b/frontends/pytorch/python/torch_mlir_utils/codegen/torch_ods_gen.py
@@ -444,8 +444,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
 
     # Primitive ops
-    emit("aten::gt.int : (int, int) -> (bool)")
-    emit("aten::ne.int : (int, int) -> (bool)")
+    emit("aten::gt.int : (int, int) -> (bool)", has_folder=True)
+    emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
     emit("aten::add.int : (int, int) -> (int)")
     emit("aten::mul.int : (int, int) -> (int)")
     emit("aten::add.float_int : (float, int) -> (float)")
diff --git a/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td b/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td
index 0eb5d9221..85be685e0 100644
--- a/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td
+++ b/include/npcomp/Dialect/Torch/IR/GeneratedAtenOps.td
@@ -268,6 +268,7 @@ def Torch_AtenGtIntOp : Torch_Op<"aten.gt.int", [
     Torch_BoolType:$result
   );
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
+  let hasFolder = 1;
 }
 
 def Torch_AtenNeIntOp : Torch_Op<"aten.ne.int", [
@@ -283,6 +284,7 @@ def Torch_AtenNeIntOp : Torch_Op<"aten.ne.int", [
     Torch_BoolType:$result
   );
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
+  let hasFolder = 1;
 }
 
 def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 10521892f..358a64289 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -411,6 +411,36 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenGtIntOp
+//===----------------------------------------------------------------------===//
+
+static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) {
+  return IntegerAttr::get(IntegerType::get(context, 1),
+                          static_cast<int64_t>(value));
+}
+
+OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (lhs && rhs) {
+    if (lhs.getValue().getSExtValue() > rhs.getValue().getSExtValue())
+      return getI1IntegerAttr(getContext(), true);
+  }
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// AtenNeIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
+  // `torch.aten.ne.int %x, %x` -> `false`
+  if (getOperand(0) == getOperand(1))
+    return getI1IntegerAttr(getContext(), false);
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // TensorOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/Passes.cpp b/lib/Dialect/Torch/Transforms/Passes.cpp
index 9ad9ad779..0dd4fb6e1 100644
--- a/lib/Dialect/Torch/Transforms/Passes.cpp
+++ b/lib/Dialect/Torch/Transforms/Passes.cpp
@@ -129,6 +129,16 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
   // Clean up a few stray conversion remnants.
   pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
 
+  if (options.optimize) {
+    // All the type refinement we've done above has exposed new information
+    // that allows folding away more stuff.
+    // OPT-ONLY: Right now we rely on this to eliminate certain
+    // branches that guard unreachable code that backends can't handle yet, such
+    // as lists, RaiseException, unimplemented aten ops, and
+    // only-used-in-training operations on `torch.global_slot`'s.
+    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+  }
+
   //===--------------------------------------------------------------------===//
   // Lowering ops and the !torch.vtensor type.
   //===--------------------------------------------------------------------===//
@@ -141,16 +151,6 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
   // TODO: Improve torch op canonicalizations.
   pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
 
-  if (options.optimize) {
-    // RefineTypes has exposed new type information that allows folding away
-    // more stuff.
-    // OPT-ONLY: Right now we rely on this to eliminate certain
-    // branches that guard unreachable code that backends can't handle yet, such
-    // as lists, RaiseException, unimplemented aten ops, and
-    // only-used-in-training operations on `torch.global_slot`'s.
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
-  }
-
   // Lower to linalg + guards which is the input to codegen backends.
   pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index 3c4aa0a08..c94f97398 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -30,6 +30,25 @@ func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.l
   return %0 : !torch.list
 }
 
+// CHECK-LABEL:   func @torch.aten.gt.int$evaluate() -> !torch.bool {
+// CHECK-NEXT:      %[[T:.*]] = torch.constant.bool true
+// CHECK-NEXT:      return %[[T]] : !torch.bool
+func @torch.aten.gt.int$evaluate() -> !torch.bool {
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.aten.gt.int %int4, %int2 : i64, i64 -> !torch.bool
+  return %0 : !torch.bool
+}
+
+// CHECK-LABEL:   func @torch.aten.ne.int$same_value(
+// CHECK-SAME:        %{{.*}}: i64) -> !torch.bool {
+// CHECK-NEXT:      %[[F:.*]] = torch.constant.bool false
+// CHECK-NEXT:      return %[[F]] : !torch.bool
+func @torch.aten.ne.int$same_value(%arg0: i64) -> !torch.bool {
+  %0 = torch.aten.ne.int %arg0, %arg0 : i64, i64 -> !torch.bool
+  return %0 : !torch.bool
+}
+
 // CHECK-LABEL:   func @torch.aten.len.t$of_size(
 // CHECK-SAME:        %[[ARG:.*]]: !torch.vtensor<*,f32>) -> i64 {
 // CHECK:           %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> i64
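
Illustration (not part of the patch): the commit message's goal is that guards
like the sketch below fold away before lowering. This is a hypothetical
example; the function name `fold_away_guard` and the `torch.prim.If` /
`torch.prim.If.yield` spellings are assumed for illustration and are not
exercised by the tests in this change -- only `aten.gt.int`, `aten.ne.int`,
and `torch.constant.bool` are.

func @fold_away_guard(%arg0: i64) -> i64 {
  %0 = torch.aten.ne.int %arg0, %arg0 : i64, i64 -> !torch.bool
  %1 = torch.prim.If %0 -> (i64) {
    // "Error" branch guarding ops the backend cannot handle yet
    // (RaiseException, string formatting, etc.).
    torch.prim.If.yield %arg0 : i64
  } else {
    torch.prim.If.yield %arg0 : i64
  }
  return %1 : i64
}

With the new ne.int folder, %0 folds to `torch.constant.bool false`, and the
canonicalizer now scheduled after type refinement can then collapse the
`torch.prim.If` to its else region, so the guarded code never reaches the
backend.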