// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func @torch.aten.mm$basic(
// CHECK-SAME: %[[LHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[RHS_VTENSOR:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
// CHECK: %[[LHS:.*]] = torch_c.to_builtin_tensor %[[LHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[RHS:.*]] = torch_c.to_builtin_tensor %[[RHS_VTENSOR]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[LHS_DIM_0:.*]] = tensor.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[LHS_DIM_1:.*]] = tensor.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[RHS_DIM_0:.*]] = tensor.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[RHS_DIM_1:.*]] = tensor.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[EQ:.*]] = arith.cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index
// CHECK: assert %[[EQ]], "mismatching contracting dimension for torch.aten.mm"
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[LHS_DIM_0]], %[[RHS_DIM_1]]] : tensor<?x?xf32>
// CHECK: %[[CF0:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[CF0]], %[[INIT_TENSOR]]) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
// CHECK: %[[RESULT_VTENSOR:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<?x2xf32> -> !torch.vtensor<[?,2],f32>
// CHECK: return %[[RESULT_VTENSOR]] : !torch.vtensor<[?,2],f32>
func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> {
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32>
  return %0 : !torch.vtensor<[?,2],f32>
}

// -----

// If the operands are missing dtype, we cannot lower it.
func @torch.aten.mm$no_convert$missing_dtype(%arg0: !torch.vtensor, %arg1: !torch.vtensor) -> !torch.vtensor {
  // expected-error@+1 {{failed to legalize}}
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
  return %0 : !torch.vtensor
}

// -----

// Correctly handle the case that operands are statically the wrong rank
// (rank 1 vs rank 2 expected for matmul.)
func @torch.aten.mm$no_convert$wrong_rank(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
  // expected-error@+1 {{failed to legalize}}
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// If the result is missing dtype, we cannot lower it.
func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor {
  // expected-error@+1 {{failed to legalize}}
  %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
  return %0 : !torch.vtensor
}

// -----

// CHECK-LABEL: func @integer_extract
// CHECK-SAME: (%[[A:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i64>
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
// CHECK: return %[[RET]] : !torch.int
func @integer_extract(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
  %0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
  return %0 : !torch.int
}

// -----

// CHECK: func @torch.prim.NumToTensor.Scalar$basic(%[[IN:.*]]: !torch.int) -> !torch.vtensor<[],si64> {
// CHECK: %[[INI64:.*]] = torch_c.to_i64 %[[IN]]
// CHECK: %[[NEWVEC:.*]] = linalg.init_tensor [] : tensor<i64>
// CHECK: %[[FILLVEC:.*]] = linalg.fill(%[[INI64]], %[[NEWVEC]]) : i64, tensor<i64> -> tensor<i64>
// CHECK: %[[OUTVEC:.*]] = torch_c.from_builtin_tensor %[[FILLVEC]] : tensor<i64> -> !torch.vtensor<[],si64>
// CHECK: return %[[OUTVEC]] : !torch.vtensor<[],si64>
func @torch.prim.NumToTensor.Scalar$basic(%arg0: !torch.int) -> !torch.vtensor<[],si64> {
  %0 = torch.prim.NumToTensor.Scalar %arg0 : !torch.int -> !torch.vtensor<[],si64>
  return %0 : !torch.vtensor<[],si64>
}