diff --git a/lib/Conversion/TorchToMhlo/BasicOp.cpp b/lib/Conversion/TorchToMhlo/BasicOp.cpp
index 1c33234b2..c3586b978 100644
--- a/lib/Conversion/TorchToMhlo/BasicOp.cpp
+++ b/lib/Conversion/TorchToMhlo/BasicOp.cpp
@@ -22,6 +22,7 @@
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 
 #include
 #include
@@ -628,6 +629,35 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
 } // namespace
 
+// Convert AtenGeluOp to MHLO.
+// Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))]
+namespace {
+template <>
+LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
+    AtenGeluOp op,
+    OpAdaptor adaptor,
+    ConversionPatternRewriter& rewriter) const {
+  Location loc = op.getLoc();
+  Value input = adaptor.self();
+  auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
+  if (!inputTy) {
+    return op.emitError("only ranked tensor type is supported.");
+  }
+
+  Value one = chlo::getConstantLike(rewriter, loc, 1.0, input);
+  Value two = chlo::getConstantLike(rewriter, loc, 2.0, input);
+  Value half = chlo::getConstantLike(rewriter, loc, 0.5, input);
+  auto rsqrtTwo = rewriter.create<mhlo::RsqrtOp>(loc, two);
+  auto erfElement = rewriter.create<mhlo::MulOp>(loc, input, rsqrtTwo);
+  auto erf = rewriter.create<chlo::ErfOp>(loc, erfElement);
+  auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
+  auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
+  rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
+  return success();
+}
+} // namespace
+
 // AtenErfOp
 namespace {
 template <>
@@ -984,6 +1014,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenContiguousOp);
   INSERT_ATENOP_PATTERN(AtenReluOp);
+  INSERT_ATENOP_PATTERN(AtenGeluOp);
   INSERT_ATENOP_PATTERN(AtenErfOp);
   INSERT_ATENOP_PATTERN(AtenBatchNormOp);
 
diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
index 35bca4019..75cca68bd 100644
--- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h
+++ b/lib/Conversion/TorchToMhlo/PopulatePatterns.h
@@ -25,6 +25,7 @@ void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter,
 void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter,
                                          RewritePatternSet &patterns,
                                          ConversionTarget &target);
+
 } // namespace torch_to_mhlo
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
index 1907a10c5..db66c0549 100644
--- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
+++ b/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp
@@ -23,6 +23,7 @@
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 
 using namespace mlir;
 using namespace mlir::torch;
diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToMhlo/basic.mlir
index e11e40fce..ee7da07e1 100644
--- a/test/Conversion/TorchToMhlo/basic.mlir
+++ b/test/Conversion/TorchToMhlo/basic.mlir
@@ -1,41 +1,5 @@
 // RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
 
-// CHECK-LABEL: func.func @torch.aten.tanh$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_2:.*]] = mhlo.tanh %[[VAL_1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.log$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_2:.*]] = mhlo.log %[[VAL_1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.exp$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_2:.*]] = mhlo.exponential %[[VAL_1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
 
 // -----
 
@@ -52,18 +16,6 @@ func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt
   return %0 : !torch.vtensor<[?,?],f32>
 }
 
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.neg$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_2:.*]] = mhlo.negate %[[VAL_1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
 
 // -----
 
@@ -87,329 +39,6 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> {
   return %0 : !torch.vtensor<[2],si64>
 }
 
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.addscalar$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int9 = torch.constant.int 9
-// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_3:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int9 = torch.constant.int 9
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.add.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.addscalar$alpha(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int9 = torch.constant.int 9
-// CHECK: %int2 = torch.constant.int 2
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.addscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int9 = torch.constant.int 9
-  %int2 = torch.constant.int 2
-  %0 = torch.aten.add.Scalar %arg0, %int9, %int2 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.addtensor$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[VAL_4:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int2 = torch.constant.int 2
-// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int2 = torch.constant.int 2
-  %0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.addtensor$promote(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
-// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
-// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_4]], %[[VAL_3]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si64>
-func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
-  return %0 : !torch.vtensor<[?,?],si64>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.subscalar$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int9 = torch.constant.int 9
-// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_3:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int9 = torch.constant.int 9
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.sub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int9 = torch.constant.int 9
-// CHECK: %int2 = torch.constant.int 2
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
-// CHECK: %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.subscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int9 = torch.constant.int 9
-  %int2 = torch.constant.int 2
-  %0 = torch.aten.sub.Scalar %arg0, %int9, %int2 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.subtensor$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[VAL_4:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.subtensor$alpha(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int2 = torch.constant.int 2
-// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int2 = torch.constant.int 2
-  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.subtensor$promote(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
-// CHECK: %int1 = torch.constant.int 1
-// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
-// CHECK: %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_4]], %[[VAL_3]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si64>
-func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
-  %int1 = torch.constant.int 1
-  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
-  return %0 : !torch.vtensor<[?,?],si64>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int9 = torch.constant.int 9
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_3:.*]] = chlo.broadcast_multiply %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int9 = torch.constant.int 9
-  %0 = torch.aten.mul.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.multensor$basic(
-// CHECK-SAME: %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VLA_4:.*]] = chlo.broadcast_multiply %[[VLA_2]], %[[VLA_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VLA_5]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.divscalar$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int9 = torch.constant.int 9
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
-// CHECK: %[[VAL_3:.*]] = chlo.broadcast_divide %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %int9 = torch.constant.int 9
-  %0 = torch.aten.div.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.divtensor$basic(
-// CHECK-SAME: %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VLA_4:.*]] = chlo.broadcast_divide %[[VLA_2]], %[[VLA_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VLA_5]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.gt.scalar(
-// CHECK-SAME: %arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %int3 = torch.constant.int 3
-// CHECK: %1 = mhlo.constant dense<3.000000e+00> : tensor<f32>
-// CHECK: %2 = chlo.broadcast_compare %0, %1 {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
-// CHECK: %3 = torch_c.from_builtin_tensor %2 : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %3 : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
-  %int3 = torch.constant.int 3
-  %0 = torch.aten.gt.Scalar %arg0, %int3 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.gt.tensor(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.lt.tensor(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.eq.tensor(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.ne.tensor(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
-// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
-// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
-// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
-func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
-  %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
-  return %0 : !torch.vtensor<[?,?],i1>
-}
-
 // -----
 
 // CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[],si64> {
@@ -451,24 +80,6 @@ func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.v
   return %0 : !torch.vtensor<[?,?,?],f32>
 }
 
-// -----
-
-// CHECK-LABEL: func.func @torch.aten.permute$basic(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
-// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
-// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
-// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
-// CHECK: %[[VAL_5:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
-// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
-// CHECK: return %[[VAL_6]] : !torch.vtensor<[64,4],f32>
-func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list<int> -> !torch.vtensor<[64,4],f32>
-  return %1 : !torch.vtensor<[64,4],f32>
-}
 
 // -----
 
@@ -518,20 +129,6 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.relu(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
-// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<?x?xf32>
-// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
-// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
-  return %0 : !torch.vtensor<[?,?],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.batch_norm$training(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir
new file mode 100644
index 000000000..643c44b11
--- /dev/null
+++ b/test/Conversion/TorchToMhlo/elementwise.mlir
@@ -0,0 +1,431 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.gelu(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[T4:.*]] = mhlo.rsqrt %[[T2]] : tensor<?x?xf32>
+// CHECK: %[[T5:.*]] = mhlo.multiply %[[T0]], %[[T4]] : tensor<?x?xf32>
+// CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor<?x?xf32> -> tensor<?x?xf32>
+// CHECK: %[[T7:.*]] = mhlo.add %[[T6]], %[[T1]] : tensor<?x?xf32>
+// CHECK: %[[T8:.*]] = mhlo.multiply %[[T7]], %[[T3]] : tensor<?x?xf32>
+// CHECK: %[[T9:.*]] = mhlo.multiply %[[T0]], %[[T8]] : tensor<?x?xf32>
+// CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %str = torch.constant.str "none"
+  %0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+
+// CHECK-LABEL: func.func @torch.aten.tanh$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = mhlo.tanh %[[VAL_1]] : tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.tanh %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.log$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = mhlo.log %[[VAL_1]] : tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.log %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.exp$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = mhlo.exponential %[[VAL_1]] : tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.exp %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.neg$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = mhlo.negate %[[VAL_1]] : tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.from_builtin_tensor %[[VAL_2]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_3]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addscalar$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int9 = torch.constant.int 9
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_3:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.add.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addscalar$alpha(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int9 = torch.constant.int 9
+// CHECK: %int2 = torch.constant.int 2
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_1]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.add.Scalar %arg0, %int9, %int2 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addtensor$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addtensor$alpha(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int2 = torch.constant.int 2
+// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = chlo.broadcast_add %[[VAL_2]], %[[VAL_5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.addtensor$promote(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
+// CHECK: %[[VAL_5:.*]] = chlo.broadcast_add %[[VAL_4]], %[[VAL_3]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si64>
+func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
+  return %0 : !torch.vtensor<[?,?],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subscalar$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int9 = torch.constant.int 9
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_3:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.sub.Scalar %arg0, %int9, %int1 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subscalar$alpha(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int9 = torch.constant.int 9
+// CHECK: %int2 = torch.constant.int 2
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_4:.*]] = chlo.broadcast_multiply %[[VAL_2]], %[[VAL_3]] : (tensor<f32>, tensor<f32>) -> tensor<f32>
+// CHECK: %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_1]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.subscalar$alpha(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.sub.Scalar %arg0, %int9, %int2 : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subtensor$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subtensor$alpha(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int2 = torch.constant.int 2
+// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_5:.*]] = chlo.broadcast_multiply %[[VAL_3]], %[[VAL_4]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_6:.*]] = chlo.broadcast_subtract %[[VAL_2]], %[[VAL_5]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.subtensor$promote(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si64> -> tensor<?x?xi64>
+// CHECK: %int1 = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = mhlo.convert(%[[VAL_2]]) : (tensor<?x?xi32>) -> tensor<?x?xi64>
+// CHECK: %[[VAL_5:.*]] = chlo.broadcast_subtract %[[VAL_4]], %[[VAL_3]] : (tensor<?x?xi64>, tensor<?x?xi64>) -> tensor<?x?xi64>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi64> -> !torch.vtensor<[?,?],si64>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si64>
+func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?,?],si64> {
+  %int1 = torch.constant.int 1
+  %0 = torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si64>, !torch.int -> !torch.vtensor<[?,?],si64>
+  return %0 : !torch.vtensor<[?,?],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.mulscalar$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int9 = torch.constant.int 9
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_3:.*]] = chlo.broadcast_multiply %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.mulscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %0 = torch.aten.mul.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.multensor$basic(
+// CHECK-SAME: %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VLA_4:.*]] = chlo.broadcast_multiply %[[VLA_2]], %[[VLA_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VLA_5]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.divscalar$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int9 = torch.constant.int 9
+// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<9.000000e+00> : tensor<f32>
+// CHECK: %[[VAL_3:.*]] = chlo.broadcast_divide %[[VAL_1]], %[[VAL_2]] : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.divscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %int9 = torch.constant.int 9
+  %0 = torch.aten.div.Scalar %arg0, %int9 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.divtensor$basic(
+// CHECK-SAME: %[[VLA_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VLA_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VLA_2:.*]] = torch_c.to_builtin_tensor %[[VLA_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VLA_3:.*]] = torch_c.to_builtin_tensor %[[VLA_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VLA_4:.*]] = chlo.broadcast_divide %[[VLA_2]], %[[VLA_3]] : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VLA_5:.*]] = torch_c.from_builtin_tensor %[[VLA_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VLA_5]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.gt.scalar(
+// CHECK-SAME: %arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %int3 = torch.constant.int 3
+// CHECK: %1 = mhlo.constant dense<3.000000e+00> : tensor<f32>
+// CHECK: %2 = chlo.broadcast_compare %0, %1 {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
+// CHECK: %3 = torch_c.from_builtin_tensor %2 : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %3 : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.gt.scalar(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %int3 = torch.constant.int 3
+  %0 = torch.aten.gt.Scalar %arg0, %int3 : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.gt.tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction GT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.gt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.gt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.lt.tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction LT>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.lt.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.lt.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.eq.tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction EQ>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.eq.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.ne.tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[64],f32> -> tensor<64xf32>
+// CHECK: %[[VAL_4:.*]] = chlo.broadcast_compare %[[VAL_2]], %[[VAL_3]] {compare_type = #mhlo<comparison_type FLOAT>, comparison_direction = #mhlo<comparison_direction NE>} : (tensor<?x?xf32>, tensor<64xf32>) -> tensor<?x?xi1>
+// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[64],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.permute$basic(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,64],f32> -> tensor<4x64xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_5:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32>
+// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32>
+// CHECK: return %[[VAL_6]] : !torch.vtensor<[64,4],f32>
+func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %0 = torch.prim.ListConstruct %int1, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[4,64],f32>, !torch.list<int> -> !torch.vtensor<[64,4],f32>
+  return %1 : !torch.vtensor<[64,4],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.relu(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 0.000000e+00 : f32} : (tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[VAL_3:.*]] = mhlo.maximum %[[VAL_1]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.relu %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
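As a numerical sanity check of the decomposition this pattern emits, the same op chain can be evaluated with plain libm. This is a standalone sketch, not part of the patch; the function name `geluViaErf` is illustrative:

```cpp
// Checks Gelu(x) = x * 1/2 * (1 + erf(x / sqrt(2))), mirroring the emitted
// op chain: mhlo.rsqrt, mhlo.multiply, chlo.erf, mhlo.add, mhlo.multiply x2.
#include <cmath>
#include <cstdio>

static double geluViaErf(double x) {
  double rsqrtTwo = 1.0 / std::sqrt(2.0); // mhlo.rsqrt of the constant 2.0
  double erfVal = std::erf(x * rsqrtTwo); // mhlo.multiply, then chlo.erf
  return x * ((erfVal + 1.0) * 0.5);      // mhlo.add 1, multiply 0.5, multiply x
}

int main() {
  // These values can be compared against torch.nn.functional.gelu in PyTorch.
  for (double x : {-3.0, -1.0, 0.0, 0.5, 2.0})
    std::printf("gelu(% .1f) = %.8f\n", x, geluViaErf(x));
  return 0;
}
```

Note that the conversion keeps `chlo.erf` in its output (as the FileCheck test shows); expanding it into pure `mhlo` ops is presumably left to the downstream CHLO legalization, which is why the conversion files now include `chlo_ops.h`.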