mirror of https://github.com/llvm/torch-mlir
Add folders for torch.aten.gt.int / torch.aten.ne.int
This fixes a "regression" on ResNet where we weren't folding away all the control flow. For now, our policy is to "optimize hard enough" to make that control flow go away, because we don't yet have a way to lower to the backend the stuff guarded by the control flow (RaiseException, string operations, etc.). It remains to be seen how much optimization we decide to do at this level in the fullness of time -- the torch op set is not particularly well-designed (at least not idiomatically for MLIR) for general optimization. Ideally, with really good backend support for various features, all the heavy optimization will happen at that layer on `std` ops and `scf` control flow. But I have a suspicion we might end up needing more optimization earlier in the pipeline.pull/228/head
parent
8860b5c55d
commit
224afb186e
|
@ -444,8 +444,8 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
|
emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
|
||||||
|
|
||||||
# Primitive ops
|
# Primitive ops
|
||||||
emit("aten::gt.int : (int, int) -> (bool)")
|
emit("aten::gt.int : (int, int) -> (bool)", has_folder=True)
|
||||||
emit("aten::ne.int : (int, int) -> (bool)")
|
emit("aten::ne.int : (int, int) -> (bool)", has_folder=True)
|
||||||
emit("aten::add.int : (int, int) -> (int)")
|
emit("aten::add.int : (int, int) -> (int)")
|
||||||
emit("aten::mul.int : (int, int) -> (int)")
|
emit("aten::mul.int : (int, int) -> (int)")
|
||||||
emit("aten::add.float_int : (float, int) -> (float)")
|
emit("aten::add.float_int : (float, int) -> (float)")
|
||||||
|
|
|
@ -268,6 +268,7 @@ def Torch_AtenGtIntOp : Torch_Op<"aten.gt.int", [
|
||||||
Torch_BoolType:$result
|
Torch_BoolType:$result
|
||||||
);
|
);
|
||||||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenNeIntOp : Torch_Op<"aten.ne.int", [
|
def Torch_AtenNeIntOp : Torch_Op<"aten.ne.int", [
|
||||||
|
@ -283,6 +284,7 @@ def Torch_AtenNeIntOp : Torch_Op<"aten.ne.int", [
|
||||||
Torch_BoolType:$result
|
Torch_BoolType:$result
|
||||||
);
|
);
|
||||||
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
let assemblyFormat = "$a `,` $b attr-dict `:` type($a) `,` type($b) `->` type($result)";
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [
|
def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [
|
||||||
|
|
|
@ -411,6 +411,36 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenGtIntOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static IntegerAttr getI1IntegerAttr(MLIRContext *context, bool value) {
|
||||||
|
return IntegerAttr::get(IntegerType::get(context, 1),
|
||||||
|
static_cast<int64_t>(value));
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult AtenGtIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();
|
||||||
|
auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();
|
||||||
|
if (lhs && rhs) {
|
||||||
|
if (lhs.getValue().getSExtValue() > rhs.getValue().getSExtValue())
|
||||||
|
return getI1IntegerAttr(getContext(), true);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenNeIntOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
// `torch.aten.ne.int %x, %x` -> `false`
|
||||||
|
if (getOperand(0) == getOperand(1))
|
||||||
|
return getI1IntegerAttr(getContext(), false);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// TensorOp
|
// TensorOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -129,6 +129,16 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
||||||
// Clean up a few stray conversion remnants.
|
// Clean up a few stray conversion remnants.
|
||||||
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
|
pm.addNestedPass<FuncOp>(Torch::createMaximizeValueSemanticsPass());
|
||||||
|
|
||||||
|
if (options.optimize) {
|
||||||
|
// All the type refinement we've done above has exposed new information
|
||||||
|
// that allows folding away more stuff.
|
||||||
|
// OPT-ONLY: Right now we rely on this to eliminate certain
|
||||||
|
// branches that guard unreachable code that backends can't handle yet, such
|
||||||
|
// as lists, RaiseException, unimplemented aten ops, and
|
||||||
|
// only-used-in-training operations on `torch.global_slot`'s.
|
||||||
|
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||||
|
}
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Lowering ops and the !torch.vtensor type.
|
// Lowering ops and the !torch.vtensor type.
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
@ -141,16 +151,6 @@ void mlir::NPCOMP::Torch::createLowerToNpcompBackendPipeline(
|
||||||
// TODO: Improve torch op canonicalizations.
|
// TODO: Improve torch op canonicalizations.
|
||||||
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
pm.addNestedPass<FuncOp>(createConvertTorchToStdPass());
|
||||||
|
|
||||||
if (options.optimize) {
|
|
||||||
// RefineTypes has exposed new type information that allows folding away
|
|
||||||
// more stuff.
|
|
||||||
// OPT-ONLY: Right now we rely on this to eliminate certain
|
|
||||||
// branches that guard unreachable code that backends can't handle yet, such
|
|
||||||
// as lists, RaiseException, unimplemented aten ops, and
|
|
||||||
// only-used-in-training operations on `torch.global_slot`'s.
|
|
||||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lower to linalg + guards which is the input to codegen backends.
|
// Lower to linalg + guards which is the input to codegen backends.
|
||||||
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
pm.addNestedPass<FuncOp>(createConvertTorchToLinalgPass());
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,25 @@ func @torch.aten.size$unknown_size(%arg0: !torch.vtensor<[?,3],f32>) -> !torch.l
|
||||||
return %0 : !torch.list<i64>
|
return %0 : !torch.list<i64>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.gt.int$evaluate() -> !torch.bool {
|
||||||
|
// CHECK-NEXT: %[[T:.*]] = torch.constant.bool true
|
||||||
|
// CHECK-NEXT: return %[[T]] : !torch.bool
|
||||||
|
func @torch.aten.gt.int$evaluate() -> !torch.bool {
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%int4 = torch.constant.int 4
|
||||||
|
%0 = torch.aten.gt.int %int4, %int2 : i64, i64 -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.ne.int$same_value(
|
||||||
|
// CHECK-SAME: %{{.*}}: i64) -> !torch.bool {
|
||||||
|
// CHECK-NEXT: %[[F:.*]] = torch.constant.bool false
|
||||||
|
// CHECK-NEXT: return %[[F]] : !torch.bool
|
||||||
|
func @torch.aten.ne.int$same_value(%arg0: i64) -> !torch.bool {
|
||||||
|
%0 = torch.aten.ne.int %arg0, %arg0 : i64, i64 -> !torch.bool
|
||||||
|
return %0 : !torch.bool
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.len.t$of_size(
|
// CHECK-LABEL: func @torch.aten.len.t$of_size(
|
||||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> i64 {
|
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<*,f32>) -> i64 {
|
||||||
// CHECK: %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> i64
|
// CHECK: %[[DIM:.*]] = torch.aten.dim %[[ARG]] : !torch.vtensor<*,f32> -> i64
|
||||||
|
|
Loading…
Reference in New Issue