# Patch overview (annotation only: this text sits before the first
# `diff --git` header and is not part of any hunk).
#
# Purpose: add the op `aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)`
# together with the pieces needed to lower and test it.
#
# Per-file changes:
#   e2e_testing/xfail_sets.py
#     Adds "MseLossNoReductionModule_basic" to TOSA_PASS_SET and adds a
#     trailing comma to the "AtenRoundIntModule_basic" entry.
#   include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
#     Declares Torch_AtenMseLossOp ("aten.mse_loss") with operands $self,
#     $target, $reduction and one $result; parse/print delegate to
#     parseDefaultTorchOp / printDefaultTorchOp with counts 3 and 1.
#   lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
#     Adds the rewrite pattern DecomposeAtenMseLossOp and a patterns.add /
#     target.addIllegalOp pair in the pass. The pattern returns
#     notifyMatchFailure when `reduction` is not a constant integer or when
#     the input tensor has no sizes. Otherwise it builds
#     createTensorSub(self, target), feeds that value into one more op, and:
#       reduction == Reduction::None -> replaces the op with that value;
#       reduction == Reduction::Mean -> creates a further op with
#         dim=None, keepdim=false, dtype=None;
#       any other value              -> creates a different op with the same
#         three trailing arguments.
#   lib/Dialect/Torch/Transforms/RefineTypes.cpp
#     Adds AtenMseLossOp to the op list of the branch commented
#     "Promote the two dtypes assuming non-zero rank".
#   lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
#   python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
#     Add the shape function for aten.mse_loss: unary(self) when
#     reduction == 0, otherwise an empty list.
#   python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
#     Adds the emit(...) line for the op signature.
#   python/torch_mlir_e2e_test/test_suite/reduction.py
#     Adds three e2e modules calling torch.ops.aten.mse_loss with reduction
#     0, 1 and 2; the reduction=2 module is annotated with one float32 and
#     one float64 input.
#   test/Dialect/Torch/decompose-complex-ops.mlir
#     Adds three FileCheck cases (no / mean / sum reduction). Each expects
#     sub.Tensor followed by mul.Tensor of the difference with itself; the
#     sum case then expects sum.dim_IntList, and the mean case expects
#     sum.dim_IntList, numel and div.Scalar.
#
# NOTE(review): this copy appears to have lost its line breaks, indentation
# and the text between angle brackets (e.g. `OpRewritePattern {`, `.cast()`,
# `rewriter.create(loc, ...)`, `!torch.list`, `isa(op)`). Regenerate the
# patch from the repository before trying to apply it -- TODO confirm.
# NOTE(review): in DecomposeAtenMseLossOp the final `else` handles every
# `reduction` value that is neither None nor Mean, and the shape function
# returns [] for every non-zero value; confirm that values outside 0/1/2 are
# rejected elsewhere.
# NOTE(review): only the no-reduction e2e test is added to TOSA_PASS_SET in
# this diff; presumably the mean/sum cases are not expected to pass on TOSA
# yet -- verify.
diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 474e85924..d65a8753f 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -479,7 +479,8 @@ TOSA_PASS_SET = { "ToDtypeBoolLayoutNoneStaticModule_basic", "ToCopyBoolDTypeStaticModule_basic", "HardTanhIntModule_basic", - "AtenRoundIntModule_basic" + "AtenRoundIntModule_basic", + "MseLossNoReductionModule_basic", } LTC_XFAIL_SET = { diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 72afccad5..55e29441f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4693,6 +4693,31 @@ def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [ }]; } +def Torch_AtenMseLossOp : Torch_Op<"aten.mse_loss", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMseLossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMseLossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 0c5d9a56b..9e003df0e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2845,6 +2845,57 @@ public: }; } // namespace +namespace { +class DecomposeAtenMseLossOp : public 
OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenMseLossOp op, + PatternRewriter &rewriter) const override { + + // The `reduction` arg would have only three valid values. + // 0 means no reduction. + // 1 means mean reduction. + // 2 means sum reduction. + int64_t reductionType; + if (!matchPattern(op.reduction(), m_TorchConstantInt(&reductionType))) + return rewriter.notifyMatchFailure( + op, "Expected a constant integer value for reduction"); + + Location loc = op.getLoc(); + BaseTensorType resultType = op.getType().cast(); + BaseTensorType inputType = op.self().getType().cast(); + if (!inputType.hasSizes()) + return rewriter.notifyMatchFailure( + op, "Expected the input tensor to have sizes"); + BaseTensorType subType = + inputType + .getWithSizesAndDtype(llvm::makeArrayRef(inputType.getSizes()), + resultType.getDtype()) + .cast(); + + Value sub = createTensorSub(rewriter, loc, subType, op.self(), op.target()); + Value result = rewriter.create(loc, subType, sub); + if (reductionType == torch_upstream::Reduction::None) { + rewriter.replaceOp(op, result); + return success(); + } + Value cstFalse = rewriter.create(loc, false); + Value cstNone = rewriter.create(loc); + if (reductionType == torch_upstream::Reduction::Mean) + result = rewriter.create(loc, resultType, result, + /*dim=*/cstNone, + /*keepdim=*/cstFalse, + /*dtype=*/cstNone); + else + result = rewriter.create( + loc, resultType, result, /*dim=*/cstNone, /*keepdim=*/cstFalse, + /*dtype=*/cstNone); + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -3040,6 +3091,8 @@ public: target.addIllegalOp(); patterns.add(context); target.addIllegalOp(); + patterns.add(context); + target.addIllegalOp(); for (std::string opName : legalOps) { target.addLegalOp(OperationName(opName, context)); diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp 
b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 63ae3d0e4..5571bd5de 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -756,7 +756,8 @@ void TypeAnalysis::visitOperation(Operation *op, // Promote the two dtypes assuming non-zero rank. if (isa(op)) { + AtenConvolutionOverrideableOp, AtenConvTranspose2dInputOp, + AtenMseLossOp>(op)) { auto knowledge = ValueKnowledge::getTensorPessimisticValueState(op->getContext()); knowledge.dtype = getPromotedResultTypeAssumingNonZeroRank( diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp index b18f6c4c8..1e2268422 100644 --- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp @@ -6743,6 +6743,18 @@ StringRef mlir::torch::Torch::getShapeLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.mse_loss\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.list) {\n" +" %2 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" torch.prim.If.yield %2 : !torch.list\n" +" } else {\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" torch.prim.If.yield %2 : !torch.list\n" +" }\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.native_layer_norm\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float) -> !torch.tuple, list, list> {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py 
b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py index f566d5694..2deaf4abf 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py @@ -1047,6 +1047,11 @@ def aten〇nll_loss_forward(self: List[int], target: List[int], weight: Optional def aten〇nll_loss_backward(grad_output: List[int], self: List[int], target: List[int], weight: Optional[List[int]], reduction: int, ignore_index: int, total_weight: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇mse_loss(self: List[int], target: List[int], reduction: int = 1) -> List[int]: + if reduction == 0: + return upstream_shape_functions.unary(self) + return [] + @check_shape_function([ Invocation(TensorOfShape(2, 5, 2, 2, 3), [2, 2, 3], None, None, 1e-6), # Basic case. ]) diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 29176cda5..9f3a601b8 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -406,6 +406,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)") emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") + emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") # Misc tensor ops. 
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") diff --git a/python/torch_mlir_e2e_test/test_suite/reduction.py b/python/torch_mlir_e2e_test/test_suite/reduction.py index b28d78a12..8c9225869 100644 --- a/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -601,3 +601,61 @@ class ReduceFrobeniusNormKeepDimModule(torch.nn.Module): @register_test_case(module_factory=lambda: ReduceFrobeniusNormKeepDimModule()) def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils): module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + +class MseLossNoReductionModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1 , -1], torch.float32, True), + ([-1 , -1], torch.float32, True), + ]) + + def forward(self, x, y): + return torch.ops.aten.mse_loss(x, y, reduction=0) + +@register_test_case(module_factory=lambda: MseLossNoReductionModule()) +def MseLossNoReductionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4), tu.rand(2, 4)) + + +class MseLossMeanReductionModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1 , -1], torch.float32, True), + ([-1 , -1], torch.float32, True), + ]) + + def forward(self, x, y): + return torch.ops.aten.mse_loss(x, y, reduction=1) + +@register_test_case(module_factory=lambda: MseLossMeanReductionModule()) +def MseLossMeanReductionModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4), tu.rand(2, 4)) + + +class MseLossSumReductionWithDifferentElemTypeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1 , -1], torch.float32, True), + ([-1 , -1], torch.float64, True), + ]) + + def forward(self, x, y): + return torch.ops.aten.mse_loss(x, y, reduction=2) + +@register_test_case(module_factory=lambda: 
MseLossSumReductionWithDifferentElemTypeModule()) +def MseLossSumReductionWithDifferentElemTypeModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4), tu.rand(2, 4).to(torch.float64)) diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index b47b6ecb5..2e09d97d2 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -991,3 +991,59 @@ func.func @torch.aten.roll(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int, %2 = torch.aten.roll %arg0, %0, %1 : !torch.vtensor<[?,?],f32>, !torch.list, !torch.list -> !torch.vtensor<[?,?],f32> return %2 : !torch.vtensor<[?,?],f32> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.mse_loss$no_reduction( +// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[TARGET:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[REDUCTION:.*]] = torch.constant.int 0 +// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[SELF]], %[[TARGET]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32> +// CHECK: %[[SUB_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[SUB]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[SUB_SQUARE]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.mse_loss$no_reduction(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int0 = torch.constant.int 0 + %0 = torch.aten.mse_loss %arg0, %arg1, %int0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.mse_loss$mean_reduction( +// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[TARGET:.*]]: !torch.vtensor<[?,?],f32>) -> 
!torch.vtensor<[?,?],f32> { +// CHECK: %[[REDUCTION:.*]] = torch.constant.int 1 +// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[SELF]], %[[TARGET]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32> +// CHECK: %[[SUB_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[SUB]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32> +// CHECK: %[[NUMEL:.*]] = torch.aten.numel %[[SUB_SQUARE]] : !torch.vtensor<[?,?],f32> -> !torch.int +// CHECK: %[[SUB_SQUARE_MEAN:.*]] = torch.aten.div.Scalar %[[SUB_SQUARE_SUM]], %[[NUMEL]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[SUB_SQUARE_MEAN]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.mse_loss$mean_reduction(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int1 = torch.constant.int 1 + %0 = torch.aten.mse_loss %arg0, %arg1, %int1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.mse_loss$sum_reduction( +// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[TARGET:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[REDUCTION:.*]] = torch.constant.int 2 +// CHECK: %[[ALPHA:.*]] = torch.constant.float 1.000000e+00 +// CHECK: %[[SUB:.*]] = torch.aten.sub.Tensor %[[SELF]], %[[TARGET]], %[[ALPHA]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32> +// 
CHECK: %[[SUB_SQUARE:.*]] = torch.aten.mul.Tensor %[[SUB]], %[[SUB]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32> +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[SUB_SQUARE_SUM:.*]] = torch.aten.sum.dim_IntList %[[SUB_SQUARE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[?,?],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[SUB_SQUARE_SUM]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.mse_loss$sum_reduction(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int2 = torch.constant.int 2 + %0 = torch.aten.mse_loss %arg0, %arg1, %int2 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32> + return %0 : !torch.vtensor<[?,?],f32> +}