2022-06-24 09:34:39 +08:00
|
|
|
// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s
|
|
|
|
|
[MLIR][Torch] Canonicalize torch.from_i1 and torch.to_i1 (#3067)
When lowering `torch.aten.convolution`, it is expected that the
'transposed' argument is a torch.constant operation. In some cases, the
argument was a `from_i1` operation converting an `arith.constant`
operation into a torch.bool. This is not wrong semantically, but instead
of generalizing the legality of the `torch.aten.convolution` op, we
canonicalize `arith.constant` ops followed by `from_i1` ops to
`torch.constant.bool` ops.
For example:
```
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.convolution'(0x124705b90) {
%33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>
* Fold {
} -> FAILURE : unable to fold
* Pattern : 'torch.aten.convolution -> ()' {
** Failure : unimplemented: only constant transposed supported. <-- Resolved by this PR
} -> FAILURE : pattern failed to match
* Pattern : 'torch.aten.convolution -> ()' {
** Failure : not a supported Scalar to Tensor like op
} -> FAILURE : pattern failed to match
* Pattern : 'torch.aten.convolution -> ()' {
** Failure : not a supported elementwise op
} -> FAILURE : pattern failed to match
* Pattern : 'torch.aten.convolution -> ()' {
** Failure : not a supported reduce op
} -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<stdin>:21:11: error: failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal
%17 = torch.operator "onnx.Conv"(%arg0, %0, %1) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [5 : si64, 5 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>) -> !torch.vtensor<[1,10,24,24],f32>
^
<stdin>:21:11: note: see current operation: %33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>
```
Additionally, we require the canonicalization of `to_i1` operating on a
`torch.constant.bool` to an `arith.constant ... : i1` for the e2e tests to
pass successfully.
2024-04-02 05:25:51 +08:00
|
|
|
// Test case: `torch_c.from_i1` fed by a builtin `arith.constant true` (i1).
// After canonicalization the function is expected to return a
// `torch.constant.bool true` value directly.
// CHECK-LABEL: func.func @torch_c.from_i1() -> !torch.bool {
|
|
|
|
// CHECK: %[[TRUE:.*]] = torch.constant.bool true
|
|
|
|
// CHECK: return %[[TRUE]] : !torch.bool
|
|
|
|
func.func @torch_c.from_i1() -> !torch.bool {
|
|
|
|
%c1_i1 = arith.constant true
|
|
|
|
%0 = torch_c.from_i1 %c1_i1
|
|
|
|
return %0 : !torch.bool
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test case: `torch_c.to_i1` fed by a `torch.constant.bool true`.
// After canonicalization the function is expected to return a builtin
// `arith.constant true` (i1) value directly.
// CHECK-LABEL: func.func @torch_c.to_i1() -> i1 {
|
|
|
|
// CHECK: %[[C1_I1:.*]] = arith.constant true
|
|
|
|
// CHECK: return %[[C1_I1]] : i1
|
|
|
|
func.func @torch_c.to_i1() -> i1 {
|
|
|
|
%bool1 = torch.constant.bool true
|
|
|
|
%0 = torch_c.to_i1 %bool1
|
|
|
|
return %0 : i1
|
|
|
|
}
|
|
|
|
|
2022-06-24 09:34:39 +08:00
|
|
|
// Test case: `torch_c.from_i64` fed by a builtin `arith.constant 5 : i64`.
// After canonicalization the function is expected to return a
// `torch.constant.int 5` value directly.
// CHECK-LABEL: func.func @torch_c.from_i64() -> !torch.int {
|
|
|
|
// CHECK: %[[INT5:.*]] = torch.constant.int 5
|
|
|
|
// CHECK: return %[[INT5]] : !torch.int
|
|
|
|
func.func @torch_c.from_i64() -> !torch.int {
|
|
|
|
%c5_i64 = arith.constant 5 : i64
|
|
|
|
%0 = torch_c.from_i64 %c5_i64
|
|
|
|
return %0 : !torch.int
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test case: `torch_c.to_i64` fed by a `torch.constant.int 5`.
// After canonicalization the function is expected to return a builtin
// `arith.constant 5 : i64` value directly.
// CHECK-LABEL: func.func @torch_c.to_i64() -> i64 {
|
|
|
|
// CHECK: %[[C5_I64:.*]] = arith.constant 5 : i64
|
|
|
|
// CHECK: return %[[C5_I64]] : i64
|
|
|
|
func.func @torch_c.to_i64() -> i64 {
|
|
|
|
%int5 = torch.constant.int 5
|
|
|
|
%0 = torch_c.to_i64 %int5
|
|
|
|
return %0 : i64
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test case: round trip i64 -> !torch.int -> i64 on a constant 5.
// After canonicalization the expected output contains only the builtin
// `arith.constant 5 : i64`, which is returned directly.
// CHECK-LABEL: func.func @torch_c.from_i64$to_i64() -> i64 {
|
|
|
|
// CHECK: %[[C5_I64:.*]] = arith.constant 5 : i64
|
|
|
|
// CHECK: return %[[C5_I64]] : i64
|
|
|
|
func.func @torch_c.from_i64$to_i64() -> i64 {
|
|
|
|
%c5_i64 = arith.constant 5 : i64
|
|
|
|
%0 = torch_c.from_i64 %c5_i64
|
|
|
|
%1 = torch_c.to_i64 %0
|
|
|
|
return %1 : i64
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test case: round trip !torch.int -> i64 -> !torch.int on a constant 5.
// After canonicalization the expected output contains only the
// `torch.constant.int 5`, which is returned directly.
// CHECK-LABEL: func.func @torch_c.to_i64$from_i64() -> !torch.int {
|
|
|
|
// CHECK: %[[INT5:.*]] = torch.constant.int 5
|
|
|
|
// CHECK: return %[[INT5]] : !torch.int
|
|
|
|
func.func @torch_c.to_i64$from_i64() -> !torch.int {
|
|
|
|
%int5 = torch.constant.int 5
|
|
|
|
%0 = torch_c.to_i64 %int5
|
|
|
|
%1 = torch_c.from_i64 %0
|
|
|
|
return %1 : !torch.int
|
|
|
|
}
|
2022-08-22 09:49:39 +08:00
|
|
|
|
|
|
|
// Test case: `torch_c.from_f64` fed by a builtin
// `arith.constant 5.000000e+00 : f64`. After canonicalization the function
// is expected to return a `torch.constant.float 5.000000e+00` directly.
// CHECK-LABEL: func.func @torch_c.from_f64() -> !torch.float {
|
|
|
|
// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00
|
|
|
|
// CHECK: return %[[FLOAT5]] : !torch.float
|
|
|
|
func.func @torch_c.from_f64() -> !torch.float {
|
|
|
|
%c5_f64 = arith.constant 5.000000e+00 : f64
|
|
|
|
%0 = torch_c.from_f64 %c5_f64
|
|
|
|
return %0 : !torch.float
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test case: `torch_c.to_f64` fed by a `torch.constant.float 5.000000e+00`.
// After canonicalization the function is expected to return a builtin
// `arith.constant 5.000000e+00 : f64` value directly.
// CHECK-LABEL: func.func @torch_c.to_f64() -> f64 {
|
|
|
|
// CHECK: %[[C5_f64:.*]] = arith.constant 5.000000e+00 : f64
|
|
|
|
// CHECK: return %[[C5_f64]] : f64
|
|
|
|
func.func @torch_c.to_f64() -> f64 {
|
|
|
|
%float5 = torch.constant.float 5.000000e+00
|
|
|
|
%0 = torch_c.to_f64 %float5
|
|
|
|
return %0 : f64
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test case: round trip f64 -> !torch.float -> f64 on a constant 5.0.
// After canonicalization the expected output contains only the builtin
// `arith.constant 5.000000e+00 : f64`, which is returned directly.
// CHECK-LABEL: func.func @torch_c.from_f64$to_f64() -> f64 {
|
|
|
|
// CHECK: %[[C5_f64:.*]] = arith.constant 5.000000e+00 : f64
|
|
|
|
// CHECK: return %[[C5_f64]] : f64
|
|
|
|
func.func @torch_c.from_f64$to_f64() -> f64 {
|
|
|
|
%c5_f64 = arith.constant 5.000000e+00 : f64
|
|
|
|
%0 = torch_c.from_f64 %c5_f64
|
|
|
|
%1 = torch_c.to_f64 %0
|
|
|
|
return %1 : f64
|
|
|
|
}
|
|
|
|
|
|
|
|
// Test case: round trip !torch.float -> f64 -> !torch.float on a constant 5.0.
// After canonicalization the expected output contains only the
// `torch.constant.float 5.000000e+00`, which is returned directly.
// CHECK-LABEL: func.func @torch_c.to_f64$from_f64() -> !torch.float {
|
|
|
|
// CHECK: %[[FLOAT5:.*]] = torch.constant.float 5.000000e+00
|
|
|
|
// CHECK: return %[[FLOAT5]] : !torch.float
|
|
|
|
func.func @torch_c.to_f64$from_f64() -> !torch.float {
|
|
|
|
%float5 = torch.constant.float 5.000000e+00
|
|
|
|
%0 = torch_c.to_f64 %float5
|
|
|
|
%1 = torch_c.from_f64 %0
|
|
|
|
return %1 : !torch.float
|
|
|
|
}
|