mirror of https://github.com/llvm/torch-mlir
3c33dbd987
When lowering `torch.aten.convolution`, the `transposed` argument is expected to be a constant (`torch.constant.bool`) operation. In some cases, however, the argument was a `from_i1` operation that converts an `arith.constant` operation into a `!torch.bool`. This is semantically valid, but instead of generalizing the legality of the `torch.aten.convolution` op, we canonicalize an `arith.constant` op followed by a `from_i1` op into a `torch.constant.bool` op. For example, legalization previously failed with:

```
//===-------------------------------------------===//
Legalizing operation : 'torch.aten.convolution'(0x124705b90) {
  %33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>

  * Fold {
  } -> FAILURE : unable to fold

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : unimplemented: only constant transposed supported.     <-- Resolved by this PR
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported Scalar to Tensor like op
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported elementwise op
  } -> FAILURE : pattern failed to match

  * Pattern : 'torch.aten.convolution -> ()' {
    ** Failure : not a supported reduce op
  } -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
<stdin>:21:11: error: failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal
    %17 = torch.operator "onnx.Conv"(%arg0, %0, %1) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [5 : si64, 5 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>) -> !torch.vtensor<[1,10,24,24],f32>
          ^
<stdin>:21:11: note: see current operation: %33 = "torch.aten.convolution"(%arg0, %20, %21, %31, %29, %30, %19, %32, %0) : (!torch.vtensor<[1,1,28,28],f32>, !torch.vtensor<[10,1,5,5],f32>, !torch.vtensor<[10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int) -> !torch.vtensor<[1,10,24,24],f32>
```

Additionally, for the e2e tests to pass, we also canonicalize `to_i1` applied to a `torch.constant.bool` into an `arith.constant ... : i1`.
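The two canonicalizations above can be pictured as a peephole rewrite over SSA definitions. The sketch below is a toy Python model, not torch-mlir code: the `Op` class, the `FOLDS` table and `canonicalize()` are hypothetical, op names are plain strings, and the convolution's operand list is abbreviated. It only shows the shape of the rewrite, assuming each conversion op has a single operand.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

# Hypothetical toy IR for illustration only; not a torch-mlir API.


@dataclass(frozen=True)
class Op:
    result: str                    # SSA name of the result, e.g. "%t"
    name: str                      # op name, e.g. "torch_c.from_i1"
    operands: Tuple[str, ...] = ()
    value: Optional[bool] = None   # payload of constant ops only


# conversion op -> (constant op it folds through, constant op it becomes)
FOLDS = {
    "torch_c.from_i1": ("arith.constant", "torch.constant.bool"),
    "torch_c.to_i1": ("torch.constant.bool", "arith.constant"),
}


def canonicalize(ops):
    """Replace a bool conversion of a constant with a constant of the target kind."""
    defs = {}  # SSA name -> (possibly rewritten) defining op
    out = []
    for op in ops:
        fold = FOLDS.get(op.name)
        if fold is not None:
            src = defs.get(op.operands[0])
            if src is not None and src.name == fold[0]:
                # Keep the same SSA name, so users of the result now see a
                # constant operand. The source constant may become dead; a real
                # pipeline would remove it with DCE (not modeled here).
                op = Op(op.result, fold[1], (), src.value)
        defs[op.result] = op
        out.append(op)
    return out


if __name__ == "__main__":
    ir = [
        Op("%c", "arith.constant", value=False),
        Op("%t", "torch_c.from_i1", ("%c",)),
        # Operands abbreviated; %t plays the role of `transposed`.
        Op("%y", "torch.aten.convolution", ("%x", "%w", "%b", "%t")),
    ]
    for op in canonicalize(ir):
        print(op)
```

Because the rewritten op keeps the result name `%t`, the convolution's `transposed` operand ends up defined by a constant op, which is the form the constant-only check in the lowering accepts.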