diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index f169993a1..28764009a 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9383,6 +9383,31 @@ def Torch_AtenMseLossBackwardOp : Torch_Op<"aten.mse_loss_backward", [
   }];
 }
 
+def Torch_AtenL1LossOp : Torch_Op<"aten.l1_loss", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::l1_loss : (Tensor, Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$target,
+    Torch_IntType:$reduction
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenL1LossOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenL1LossOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [
   AllowsTypeRefinement,
   HasValueSemantics,
@@ -16923,6 +16948,29 @@ def Torch_AtenTrilIndicesOp : Torch_Op<"aten.tril_indices", [
   let hasVerifier = 1;
 }
 
+def Torch_AtenDeg2radOp : Torch_Op<"aten.deg2rad", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::deg2rad : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenDeg2radOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenDeg2radOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
   AllowsTypeRefinement,
   HasValueSemantics,
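# Illustrative sketch, not part of the patch: the two op registrations above
# mirror the eager PyTorch overloads below. `reduction` uses the usual integer
# encoding (0 = none, 1 = mean, 2 = sum).
import torch

x, y = torch.rand(2, 4), torch.rand(2, 4)
per_element = torch.ops.aten.l1_loss(x, y, 0)  # aten::l1_loss : (Tensor, Tensor, int) -> (Tensor)
radians = torch.ops.aten.deg2rad(x)            # aten::deg2rad : (Tensor) -> (Tensor)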
diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index c4c3a874f..d6ba57a08 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -1143,6 +1143,49 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+// AtenLogitOp
+template <>
+LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
+    AtenLogitOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto loc = op.getLoc();
+
+  Value self = adaptor.getSelf();
+  auto selfTy = dyn_cast<RankedTensorType>(self.getType());
+  if (!selfTy) {
+    return op.emitError("only ranked tensor type is supported.");
+  }
+
+  auto outTy =
+      cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+  self = hlo::promoteType(rewriter, op.getLoc(), self, outTy.getElementType());
+
+  selfTy = dyn_cast<RankedTensorType>(self.getType());
+
+  Value eps = adaptor.getEps();
+  auto epsTy = eps.getType();
+  Value newSelf;
+  if (!isa<Torch::NoneType>(epsTy)) {
+    auto epsTensor = hlo::scalarToStablehloTensor(rewriter, op, eps,
+                                                  selfTy.getElementType());
+    Value oneEpsTensor = hlo::getConstantLike(rewriter, loc, 1.0, epsTensor);
+    auto max =
+        rewriter.create<stablehlo::SubtractOp>(loc, oneEpsTensor, epsTensor);
+    newSelf = rewriter.create<stablehlo::ClampOp>(loc, epsTensor, self, max);
+  } else {
+    newSelf = self;
+  }
+
+  Value one = hlo::getConstantLike(rewriter, loc, 1.0, self);
+  Value zi1 = rewriter.create<stablehlo::SubtractOp>(loc, one, newSelf);
+  Value newZi = rewriter.create<stablehlo::DivOp>(loc, newSelf, zi1);
+
+  Value log =
+      rewriter.create<stablehlo::LogOp>(loc, outTy, newZi);
+
+  rewriter.replaceOp(op, log);
+
+  return success();
+}
+
 // AtenErfOp
 template <>
 LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
@@ -2248,6 +2291,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenGeluOp);
   INSERT_ATENOP_PATTERN(AtenLog2Op);
   INSERT_ATENOP_PATTERN(AtenLog10Op);
+  INSERT_ATENOP_PATTERN(AtenLogitOp);
   INSERT_ATENOP_PATTERN(AtenErfOp);
   INSERT_ATENOP_PATTERN(AtenGeluBackwardOp);
 
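# Illustrative sketch, not part of the patch: the StableHLO lowering above
# computes logit(x) = log(x / (1 - x)), clamping x to [eps, 1 - eps] first when
# `eps` is provided, which matches torch.logit.
import torch

x = torch.rand(3, 4)
eps = 1e-6
clamped = x.clamp(eps, 1.0 - eps)
assert torch.allclose(torch.log(clamped / (1.0 - clamped)), torch.logit(x, eps=eps))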
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 560b6a821..1cc02a48f 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10465,6 +10465,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.deg2rad\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -10485,6 +10489,18 @@
 "    }\n"
 "    return %1 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.l1_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.list<int> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
+"      %2 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"      torch.prim.If.yield %2 : !torch.list<int>\n"
+"    } else {\n"
+"      %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+"      torch.prim.If.yield %2 : !torch.list<int>\n"
+"    }\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.cross_entropy_loss\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int, !torch.int, !torch.float) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -13864,6 +13880,24 @@
 "    }\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.l1_loss\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n"
+"    %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n"
+"    torch.prim.If %6 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -15918,6 +15952,11 @@
 "    }\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.deg2rad\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int3 = torch.constant.int 3\n"
 "    %int1 = torch.constant.int 1\n"
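# Illustrative sketch, not part of the patch: the generated shape function for
# aten.l1_loss keeps the input shape only for reduction == 0 and otherwise
# yields a rank-0 result; the dtype function raises AssertionError when the
# promoted dtype is an integer type, and aten.deg2rad always produces a
# floating-point dtype.
def l1_loss_shape(self_shape, target_shape, reduction):
    return list(self_shape) if reduction == 0 else []

assert l1_loss_shape([2, 4], [2, 4], reduction=0) == [2, 4]
assert l1_loss_shape([2, 4], [2, 4], reduction=1) == []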
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 445a354d4..2f276b1a2 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1334,6 +1334,44 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenDeg2radOp : public OpRewritePattern<AtenDeg2radOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenDeg2radOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    auto selfTy = dyn_cast<BaseTensorType>(self.getType());
+    if (!selfTy || !selfTy.getDtype()) {
+      return rewriter.notifyMatchFailure(op, "requires tensor types input.");
+    }
+
+    auto outTy = dyn_cast<BaseTensorType>(op.getType());
+    if (!outTy || !outTy.getDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "requires output is a tensor with dtype.");
+    }
+
+    if (selfTy.getDtype() != outTy.getDtype()) {
+      self = convertTensorToDtype(rewriter, loc, self, outTy.getDtype());
+    }
+
+    Value pi =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(M_PI));
+    Value basic =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(180.0));
+    Value rad =
+        rewriter.create<AtenDivScalarOp>(loc, op.getType(), self, basic);
+    Value result = rewriter.create<AtenMulScalarOp>(loc, op.getType(), rad, pi);
+
+    rewriter.replaceOp(op, result);
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 public:
@@ -8640,6 +8678,71 @@ public:
 };
 } // namespace
 
+namespace {
+class DecomposeAtenL1LossOp : public OpRewritePattern<AtenL1LossOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenL1LossOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.getSelf();
+    auto selfTy = dyn_cast<BaseTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasSizes() || !selfTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected self to be a tensor with sizes and a dtype");
+    }
+
+    Value target = op.getTarget();
+    auto targetTy = dyn_cast<BaseTensorType>(target.getType());
+    if (!targetTy || !targetTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected target to be a tensor with sizes and a dtype");
+    }
+
+    auto outTy = dyn_cast<BaseTensorType>(op.getType());
+    if (!outTy || !outTy.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected output type to be a tensor with a dtype");
+    }
+
+    auto outDtype = outTy.getDtype();
+    if (selfTy.getDtype() != outDtype) {
+      self = convertTensorToDtype(rewriter, loc, self, outDtype);
+    }
+    if (targetTy.getDtype() != outDtype) {
+      target = convertTensorToDtype(rewriter, loc, target, outDtype);
+    }
+
+    Value reduction = op.getReduction();
+    int64_t reductionInt;
+    if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected reduction to be a constant int");
+    }
+
+    auto subTy = outTy.getWithSizesAndDtype(selfTy.getSizes(), outDtype);
+    Value sub = createTensorSub(rewriter, loc, subTy, self, target);
+    Value abs = rewriter.create<AtenAbsOp>(loc, subTy, sub);
+
+    if (reductionInt == 0) {
+      rewriter.replaceOp(op, abs);
+    } else if (reductionInt == 1) {
+      Value none = rewriter.create<ConstantNoneOp>(loc);
+      Value sum = rewriter.create<AtenSumOp>(loc, outTy, abs, none);
+      Value numel = rewriter.create<AtenNumelOp>(loc, abs);
+      Value mean = rewriter.create<AtenDivScalarOp>(loc, outTy, sum, numel);
+      rewriter.replaceOp(op, mean);
+    } else {
+      Value none = rewriter.create<ConstantNoneOp>(loc);
+      Value sum = rewriter.create<AtenSumOp>(loc, outTy, abs, none);
+      rewriter.replaceOp(op, sum);
+    }
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op
 class DecomposeAtenNormScalarOptDimOp
@@ -10776,6 +10879,7 @@ public:
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenL1LossOp>(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
@@ -10821,6 +10925,7 @@ public:
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenDeg2radOp>(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
   addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 4dd855be4..f868c4c18 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -527,6 +527,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenL1LossOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
@@ -564,6 +565,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenDeg2radOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 18adad513..e0011b9a3 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -701,7 +701,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ElementwiseDequantizePerChannelModule_basic",
     "ElementwiseDequantizePerTensorModule_basic",
     "ElementwiseErfIntModule_basic",
-    "ElementwiseLogitModule_basic",
     "ElementwiseMulTensorComplexModule_basic",
     "ElementwiseMulTensorComplexDiffModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
@@ -2899,6 +2898,7 @@ ONNX_XFAIL_SET = {
     "ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
+    "Deg2radModule_basic",
     "DivFloatModule_basic",
     "DivIntModule_basic",
     "ElementwiseAcoshIntModule_basic",
@@ -2986,6 +2986,9 @@
     "IsFloatingPointInt_False",
     "IscloseStaticModuleTrue_basic",
     "IscloseStaticModule_basic",
+    "L1LossNoReductionModule_basic",
+    "L1LossMeanReductionModule_basic",
+    "L1LossSumReductionModule_basic",
     "LeakyReluBackwardModule_basic",
     "LeakyReluBackwardStaticModule_basic",
     "LenStrModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 8a9e7755e..8dfacca32 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2062,6 +2062,9 @@ def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Opti
 
     return [2, trapezoid_size + rectangle_size]
 
+def aten〇deg2rad〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 @check_shape_function([
     Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
     Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
@@ -2080,6 +2083,11 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int =
         return upstream_shape_functions.unary(self)
     return []
 
+def aten〇l1_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]:
+    if reduction == 0:
+        return upstream_shape_functions.unary(self)
+    return []
+
 def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]:
     return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing)
 
@@ -4262,6 +4270,15 @@ def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype:
     assert not is_integer_dtype(promoted_dtype)
     return promoted_dtype
 
+def aten〇l1_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: Tuple[int, int], reduction: int = 1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    target_rank, target_dtype = target_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, target_rank]
+    dtypes = [self_dtype, target_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    assert not is_integer_dtype(promoted_dtype)
+    return promoted_dtype
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇mul〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     other_rank, other_dtype = other_rank_dtype
@@ -5734,6 +5751,10 @@ def aten〇triu_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Opti
 def aten〇tril_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
     return torch.int64 if dtype is None else dtype
 
+def aten〇deg2rad〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return _get_dtype_of_floating_point_op(self_dtype)
+
 def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     if (self_dtype == torch.quint8):
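# Illustrative sketch, not part of the patch: aten〇deg2rad〡dtype defers to
# _get_dtype_of_floating_point_op, so integer inputs are assumed to promote to
# the default floating-point dtype while floating-point inputs keep theirs,
# matching eager PyTorch.
import torch

assert torch.deg2rad(torch.tensor([180])).dtype == torch.float32
assert torch.deg2rad(torch.tensor([180.0], dtype=torch.float64)).dtype == torch.float64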
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 0913b2c67..31916f7fe 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -747,6 +747,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
     emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
     emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
+    emit("aten::l1_loss : (Tensor, Tensor, int) -> (Tensor)")
     emit(
         "aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)"
     )
@@ -1170,6 +1171,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         has_verifier=True,
     )
 
+    emit("aten::deg2rad : (Tensor) -> (Tensor)")
+
     # backprop ops
     emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
     emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index a6679ec4d..38fccc06b 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -7173,3 +7173,26 @@ class TrilIndicesOfssetGreaterThanRowModule(torch.nn.Module):
 @register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule())
 def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils):
     module.forward()
+
+
+# ==============================================================================
+
+
+class Deg2radModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.deg2rad(x)
+
+
+@register_test_case(module_factory=lambda: Deg2radModule())
+def Deg2radModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
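# Illustrative sketch, not part of the patch: the Deg2radModule e2e test above
# compares torch.ops.aten.deg2rad against eager PyTorch, which scales degrees
# by pi / 180.
import torch

deg = torch.tensor([0.0, 90.0, 180.0, 360.0])
expected = torch.tensor([0.0, torch.pi / 2, torch.pi, 2 * torch.pi])
assert torch.allclose(torch.ops.aten.deg2rad(deg), expected)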
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 89774c5d1..3e379deac 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -2260,6 +2260,78 @@ def MseLossSumReductionWithDifferentElemTypeModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class L1LossNoReductionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4], torch.float32, True),
+            ([2, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.l1_loss(x, y, reduction=0)
+
+
+@register_test_case(module_factory=lambda: L1LossNoReductionModule())
+def L1LossNoReductionModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4), tu.rand(2, 4))
+
+
+# ==============================================================================
+
+
+class L1LossMeanReductionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4], torch.float32, True),
+            ([2, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.l1_loss(x, y, reduction=1)
+
+
+@register_test_case(module_factory=lambda: L1LossMeanReductionModule())
+def L1LossMeanReductionModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4), tu.rand(2, 4))
+
+
+# ==============================================================================
+
+
+class L1LossSumReductionModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4], torch.float32, True),
+            ([2, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.ops.aten.l1_loss(x, y, reduction=2)
+
+
+@register_test_case(module_factory=lambda: L1LossSumReductionModule())
+def L1LossSumReductionModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4), tu.rand(2, 4))
+
+
+# ==============================================================================
+
+
 class CrossEntropyLossModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
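# Illustrative sketch, not part of the patch: the three L1Loss*ReductionModule
# tests above exercise reduction = 0, 1 and 2, which correspond to the "none",
# "mean" and "sum" modes of torch.nn.functional.l1_loss.
import torch
import torch.nn.functional as F

x, y = torch.rand(2, 4), torch.rand(2, 4)
assert torch.allclose(torch.ops.aten.l1_loss(x, y, 0), F.l1_loss(x, y, reduction="none"))
assert torch.allclose(torch.ops.aten.l1_loss(x, y, 1), F.l1_loss(x, y, reduction="mean"))
assert torch.allclose(torch.ops.aten.l1_loss(x, y, 2), F.l1_loss(x, y, reduction="sum"))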