Add folders for torch.aten.gt.int / torch.aten.ne.int

This fixes a "regression" on ResNet where we weren't folding away all
the control flow. For now, our policy is to "optimize hard enough" to
make that control flow go away, because we don't yet have a way to lower
the code it guards (RaiseException, string operations, etc.) to the
backend.

It remains to be seen how much optimization we decide to do at this
level in the fullness of time -- the torch op set is not particularly
well-designed (at least not idiomatically for MLIR) for general
optimization. Ideally, with really good backend support for various
features, all the heavy optimization will happen at that layer on `std`
ops and `scf` control flow. But I have a suspicion we might end up
needing more optimization earlier in the pipeline.
pull/228/head
Sean Silva 2021-06-16 11:05:08 -07:00
parent 8860b5c55d
commit 224afb186e
5 changed files with 63 additions and 12 deletions
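To make the motivation concrete, here is a hypothetical sketch of the kind of pattern being targeted (op spellings follow the Torch dialect, but the exact IR the importer emits may differ): a shape-check assert compares integers and branches to a RaiseException that backends cannot lower. Once the comparison folds to a constant, canonicalization can delete the whole branch.

// Hypothetical importer output: a runtime check guarding an exception.
// Once `torch.aten.ne.int %dim, %dim` folds to `false`, the
// torch.prim.If and the RaiseException it guards become removable.
%cond = torch.aten.ne.int %dim, %dim : i64, i64 -> !torch.bool
torch.prim.If %cond -> () {
  torch.prim.RaiseException %msg
  torch.prim.If.yield
} else {
  torch.prim.If.yield
}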


@@ -444,8 +444,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     # Primitive ops
-    emit("aten::gt.int : (int, int) -> (bool)")
-    emit("aten::ne.int : (int, int) -> (bool)")
+    emit("aten::gt.int : (int, int) -> (bool)", has_folder=True)
+    emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
     emit("aten::add.int : (int, int) -> (int)")
     emit("aten::mul.int : (int, int) -> (int)")
     emit("aten::add.float_int : (float, int) -> (float)")


@@ -268,6 +268,7 @@ def Torch_AtenGtIntOp : Torch_Op<"aten.gt.int", [
     Torch_BoolType:$result
   );
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
+  let hasFolder = 1;
 }

 def Torch_AtenNeIntOp : Torch_Op<"aten.ne.int", [
@@ -283,6 +284,7 @@ def Torch_AtenNeIntOp : Torch_Op<"aten.ne.int", [
     Torch_BoolType:$result
   );
   let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
+  let hasFolder = 1;
 }

 def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [


@@ -411,6 +411,36 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   });
 }

+//===----------------------------------------------------------------------===//
+// AtenGtIntOp
+//===----------------------------------------------------------------------===//
+
+static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) {
+  return IntegerAttr::get(IntegerType::get(context, 1),
+                          static_cast<int64_t>(value));
+}
+
+OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
+  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
+  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (lhs && rhs) {
+    if (lhs.getValue().getSExtValue() > rhs.getValue().getSExtValue())
+      return getI1IntegerAttr(getContext(), true);
+  }
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// AtenNeIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
+  // `torch.aten.ne.int %x, %x` -> `false`
+  if (getOperand(0) == getOperand(1))
+    return getI1IntegerAttr(getContext(), false);
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // TensorOp
 //===----------------------------------------------------------------------===//
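One property of the folder above worth noting: `AtenGtIntOp::fold` only produces an attribute when the comparison holds, so a statically-false `gt.int` is left in the IR (the hook returns nullptr). A minimal sketch of the two cases, with constants named as in the tests below:

%int2 = torch.constant.int 2
%int4 = torch.constant.int 4
// Folds to `torch.constant.bool true`: 4 > 2 holds.
%0 = torch.aten.gt.int %int4, %int2 : i64, i64 -> !torch.bool
// Left in place by this folder: 2 > 4 is false, so fold() returns nullptr.
%1 = torch.aten.gt.int %int2, %int4 : i64, i64 -> !torch.bool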


@@ -129,6 +129,16 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
   // Clean up a few stray conversion remnants.
   pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());

+  if (options.optimize) {
+    // All the type refinement we've done above has exposed new information
+    // that allows folding away more stuff.
+    // OPT-ONLY: Right now we rely on this to eliminate certain
+    // branches that guard unreachable code that backends can't handle yet, such
+    // as lists, RaiseException, unimplemented aten ops, and
+    // only-used-in-training operations on `torch.global_slot`'s.
+    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
+  }
+
   //===--------------------------------------------------------------------===//
   // Lowering ops and the !torch.vtensor type.
   //===--------------------------------------------------------------------===//
@@ -141,16 +151,6 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
   // TODO: Improve torch op canonicalizations.
   pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());

-  if (options.optimize) {
-    // RefineTypes has exposed new type information that allows folding away
-    // more stuff.
-    // OPT-ONLY: Right now we rely on this to eliminate certain
-    // branches that guard unreachable code that backends can't handle yet, such
-    // as lists, RaiseException, unimplemented aten ops, and
-    // only-used-in-training operations on `torch.global_slot`'s.
-    pm.addNestedPass<FuncOp>(createCanonicalizerPass());
-  }
-
   // Lower to linalg + guards which is the input to codegen backends.
   pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
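Why folding the predicate suffices to delete the branch: once the condition becomes a `torch.constant.bool`, canonicalization of the control-flow op can inline the statically-taken region. This sketch assumes such a `torch.prim.If` canonicalization exists, as the commit's reliance on the canonicalizer implies:

// Before the canonicalizer runs:
%true = torch.constant.bool true
%0 = torch.prim.If %true -> (i64) {
  torch.prim.If.yield %a : i64
} else {
  torch.prim.If.yield %b : i64
}
// After: %0 is replaced by %a, and the torch.prim.If (along with any
// unreachable RaiseException in the untaken region) is erased.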


@@ -30,6 +30,25 @@ func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.list<i64> {
   return %0 : !torch.list<i64>
 }

+// CHECK-LABEL:   func @torch.aten.gt.int$evaluate() -> !torch.bool {
+// CHECK-NEXT:      %[[T:.*]] = torch.constant.bool true
+// CHECK-NEXT:      return %[[T]] : !torch.bool
+func @torch.aten.gt.int$evaluate() -> !torch.bool {
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.aten.gt.int %int4, %int2 : i64, i64 -> !torch.bool
+  return %0 : !torch.bool
+}
+
+// CHECK-LABEL:   func @torch.aten.ne.int$same_value(
+// CHECK-SAME:        %{{.*}}: i64) -> !torch.bool {
+// CHECK-NEXT:      %[[F:.*]] = torch.constant.bool false
+// CHECK-NEXT:      return %[[F]] : !torch.bool
+func @torch.aten.ne.int$same_value(%arg0: i64) -> !torch.bool {
+  %0 = torch.aten.ne.int %arg0, %arg0 : i64, i64 -> !torch.bool
+  return %0 : !torch.bool
+}
+
 // CHECK-LABEL:   func @torch.aten.len.t$of_size(
 // CHECK-SAME:        %[[ARG:.*]]: !torch.vtensor<*,f32>) -> i64 {
 // CHECK:           %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> i64
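For completeness: FileCheck tests in this file are presumably driven by a lit RUN line at the top of the file, which this hunk does not show. A typical invocation for a canonicalization test would be (the exact command here is an assumption):

// RUN: npcomp-opt %s -canonicalize | FileCheck %s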