diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 816be2850..4dc96e7d0 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -1029,6 +1029,55 @@ def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [
   }];
 }
 
+def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other,
+    AnyTorchOptionalStringType:$rounding_mode
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenDivTensorModeOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenDivTensorModeOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
+def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::div_.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other,
+    AnyTorchOptionalStringType:$rounding_mode
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenDiv_TensorModeOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenDiv_TensorModeOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [
   AllowsTypeRefinement,
   HasValueSemantics,
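Reviewer note: the two ODS records above are generated from the registry entry added in torch_ods_gen.py further down; `aten.div_.Tensor_mode` is the usual trailing-underscore in-place variant. For reference, the eager PyTorch semantics the new op models (real `torch.div` API; outputs are stock PyTorch results):

```python
import torch

a = torch.tensor([7.0, -7.0])
b = torch.tensor([2.0, 2.0])

torch.div(a, b)                         # tensor([ 3.5000, -3.5000]); true division
torch.div(a, b, rounding_mode="trunc")  # tensor([ 3., -3.]); round toward zero
torch.div(a, b, rounding_mode="floor")  # tensor([ 3., -4.]); round toward -inf
```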
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 2111145a7..950f95974 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -447,12 +447,54 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Type dtype = converter->convertType(div.getType())
                      .cast<RankedTensorType>()
                      .getElementType();
-    if (!dtype.isa<mlir::FloatType>())
+    if (!dtype.isa<mlir::FloatType>()) {
       div.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<arith::DivFOp>(loc, lhs, rhs);
   }
+  if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
+    AtenDivTensorModeOp::Adaptor adaptor(operands);
+    Type dtype = converter->convertType(divTensorMode.getType())
+                     .cast<RankedTensorType>()
+                     .getElementType();
+    if (!dtype.isa<mlir::FloatType>()) {
+      divTensorMode.emitError("unimplemented: non-floating point dtype");
+      return nullptr;
+    }
+    Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
+    Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
+    Value div = b.create<arith::DivFOp>(loc, lhs, rhs);
+
+    if (divTensorMode.rounding_mode().getType().isa<Torch::NoneType>())
+      return div;
+
+    std::string roundingMode;
+    if (!matchPattern(divTensorMode.rounding_mode(),
+                      m_TorchConstantStr(roundingMode))) {
+      divTensorMode.emitError("only support constant str rounding mode");
+      return nullptr;
+    }
+    if (roundingMode == "trunc") {
+      // "trunc" - rounds the results of the division towards zero. Equivalent
+      // to C-style integer division.
+      Value ceil = b.create<math::CeilOp>(loc, div);
+      Value floor = b.create<math::FloorOp>(loc, div);
+      Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
+      Value pred =
+          b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, div, cstZero);
+      return b.create<arith::SelectOp>(loc, pred, ceil, floor);
+    }
+    if (roundingMode == "floor") {
+      // "floor" - rounds the results of the division down. Equivalent to
+      // floor division in Python (the // operator).
+      return b.create<math::FloorOp>(loc, div);
+    }
+    divTensorMode.emitError("invalid rounding mode");
+    return nullptr;
+  }
   if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
     if (!pow.getType()
              .cast<ValueTensorType>()
@@ -845,17 +887,17 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
-    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
-             AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
-             AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
-             AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
-             AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
-             AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp,
-             AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-             AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
-             AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
-             AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
-             AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
-             AtenMaskedFillScalarOp, AtenLogicalOrOp>(op))
+    if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
+             AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
+             AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp,
+             AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
+             AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
+             AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp,
+             AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenAbsOp,
+             AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
+             AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
+             AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
+             AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
+             AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
+             AtenLogicalOrOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@@ -1585,15 +1627,15 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   MLIRContext *context = patterns.getContext();
   target.addIllegalOp<
       AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
-      AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
-      AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
-      AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp,
-      AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op,
-      AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-      AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
-      AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
-      AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
-      AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
+      AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
+      AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
+      AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
+      AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
+      AtenLog2Op, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
+      AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
+      AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
+      AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
+      AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
       AtenLogicalOrOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp();
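Reviewer note: the `trunc` branch relies on the identity trunc(x) = ceil(x) for x < 0 and floor(x) otherwise, since the payload has no dedicated truncation op; `arith::CmpFPredicate::ULT` is "unordered or less than", so a NaN quotient takes the ceil branch and stays NaN either way. A minimal Python sketch, just to make the identity concrete:

```python
import math

def trunc_via_select(x: float) -> float:
    # Mirrors the lowering: select(x < 0, ceil(x), floor(x)).
    return math.ceil(x) if x < 0 else math.floor(x)

assert trunc_via_select(3.5) == math.trunc(3.5) == 3
assert trunc_via_select(-3.5) == math.trunc(-3.5) == -3
```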
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 773fc3a71..a1c5e8470 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -701,8 +701,8 @@ ChangeResult TypeAnalyzer::visitOperation(
   // Promote the two dtypes assuming possibly-zero rank.
   if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
-          Aten__And__TensorOp, AtenMinimumOp, AtenMaximumOp,
-          AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) {
+          AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
+          AtenMaximumOp, AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultType(
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index a820f4693..6cf81190a 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -2196,6 +2196,10 @@ module {
     %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.div.Tensor_mode"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<str>) -> !torch.list<int> {
+    %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index a90d6fec0..9766fd426 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -711,6 +711,9 @@ def aten〇mul〇Tensor(self: List[int], other: List[int]) -> List[int]:
 def aten〇div〇Tensor(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_helpers.broadcast(self, other)
 
+def aten〇div〇Tensor_mode(self: List[int], other: List[int], rounding_mode: Optional[str]) -> List[int]:
+    return upstream_shape_helpers.broadcast(self, other)
+
 def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_helpers.broadcast(self, other)
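Reviewer note: `rounding_mode` has no effect on the result shape, so the new shape function simply reuses the broadcast helper, exactly like `aten.div.Tensor`. A simplified sketch of what `upstream_shape_helpers.broadcast` computes (the real helper lives upstream in PyTorch and produces richer diagnostics; this is just the right-aligned broadcasting rule):

```python
from typing import List

def broadcast(a: List[int], b: List[int]) -> List[int]:
    # Right-align both shapes; size-1 dims stretch, equal dims pass through.
    result: List[int] = []
    for i in range(max(len(a), len(b))):
        x = a[-1 - i] if i < len(a) else 1
        y = b[-1 - i] if i < len(b) else 1
        if x != 1 and y != 1 and x != y:
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(max(x, y))
    return result[::-1]

assert broadcast([3, 4, 5], [4, 1]) == [3, 4, 5]
```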
-> (Tensor)", "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", diff --git a/python/torch_mlir_e2e_test/test_suite/elementwise.py b/python/torch_mlir_e2e_test/test_suite/elementwise.py index 35bd261ac..a2a492c6a 100644 --- a/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -18,7 +18,9 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export # ============================================================================== + class ElementwiseUnaryModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -35,9 +37,12 @@ class ElementwiseUnaryModule(torch.nn.Module): def ElementwiseUnaryModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4)) + # ============================================================================== + class ElementwiseUnaryIntModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -54,9 +59,12 @@ class ElementwiseUnaryIntModule(torch.nn.Module): def ElementwiseUnaryIntModule_basic(module, tu: TestUtils): module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32)) + # ============================================================================== + class ElementwiseBinaryModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -74,9 +82,12 @@ class ElementwiseBinaryModule(torch.nn.Module): def ElementwiseBinaryModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4), tu.rand(4)) + # ============================================================================== + class ElementwiseBinaryStaticShapeModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -95,9 +106,12 @@ class ElementwiseBinaryStaticShapeModule(torch.nn.Module): def ElementwiseBinaryStaticShapeModule_basic(module, tu: TestUtils): module.forward(tu.rand(5, 4, 3, 3, 1), tu.rand(4, 3, 1, 2)) + # ============================================================================== + class ElementwiseTernaryModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -116,9 +130,12 @@ class ElementwiseTernaryModule(torch.nn.Module): def ElementwiseTernaryModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5)) + # ============================================================================== + class ElementwiseWhereSelfModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -137,9 +154,12 @@ class ElementwiseWhereSelfModule(torch.nn.Module): def ElementwiseWhereSelfModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5)) + # ============================================================================== + class ElementwiseWhereScalarModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -156,9 +176,12 @@ class ElementwiseWhereScalarModule(torch.nn.Module): def ElementwiseWhereScalarModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5)) + # ============================================================================== + class ElementwiseWhereScalarOtherModule(torch.nn.Module): + def __init__(self): super().__init__() @@ -176,9 +199,12 @@ class ElementwiseWhereScalarOtherModule(torch.nn.Module): def ElementwiseWhereScalarOtherModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double()) + # ============================================================================== + class 
 class ElementwiseWhereScalarSelfModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -196,11 +222,14 @@ class ElementwiseWhereScalarSelfModule(torch.nn.Module):
 def ElementwiseWhereScalarSelfModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
 
+
 # ==============================================================================
 
+
 # Addition is an interesting special case of a binary op, because under the hood
 # it carries a third scalar "alpha" parameter, which needs special handling.
 class ElementwiseAddModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -218,9 +247,12 @@ class ElementwiseAddModule(torch.nn.Module):
 def ElementwiseAddModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand())
 
+
 # ==============================================================================
 
+
 class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -239,9 +271,12 @@ class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
 def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand())
 
+
 # ==============================================================================
 
+
 class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -261,9 +296,12 @@ class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
 def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 3))
 
+
 # ==============================================================================
 
+
 class ElementwiseFlattenBroadcastModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -281,9 +319,12 @@ class ElementwiseFlattenBroadcastModule(torch.nn.Module):
 def ElementwiseFlattenBroadcastModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(6), tu.rand())
 
+
 # ==============================================================================
 
+
 class ElementwiseReluModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -300,9 +341,12 @@ class ElementwiseReluModule(torch.nn.Module):
 def ElementwiseReluModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 2) - 0.5)
 
+
 # ==============================================================================
 
+
 class ElementwiseLeakyReluModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -319,9 +363,12 @@ class ElementwiseLeakyReluModule(torch.nn.Module):
 def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 2) - 0.5)
 
+
 # ==============================================================================
 
+
 class ElementwiseGeluModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
         self.gelu = torch.nn.GELU()
 
@@ -339,9 +386,12 @@ class ElementwiseGeluModule(torch.nn.Module):
 def ElementwiseGeluModule_basic(module, tu: TestUtils):
     module.forward(2 * tu.rand(5, 3) - 0.5)
 
+
 # ==============================================================================
 
+
 class ElementwiseSigmoidModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -358,9 +408,12 @@ class ElementwiseSigmoidModule(torch.nn.Module):
 def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
+
 # ==============================================================================
 
+
 class ElementwiseSigmoidIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -377,9 +430,12 @@ class ElementwiseSigmoidIntModule(torch.nn.Module):
 def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (3, 5), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseMinimumModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -397,9 +453,12 @@ class ElementwiseMinimumModule(torch.nn.Module):
 def ElementwiseMinimumModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5), tu.rand(3, 5))
 
+
 # ==============================================================================
 
+
 class ElementwiseMinimumIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -417,9 +476,12 @@ class ElementwiseMinimumIntModule(torch.nn.Module):
 def ElementwiseMinimumIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5)))
 
+
 # ==============================================================================
 
+
 class ElementwiseMaximumModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -437,9 +499,12 @@ class ElementwiseMaximumModule(torch.nn.Module):
 def ElementwiseMaximumModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5), tu.rand(3, 5))
 
+
 # ==============================================================================
 
+
 class ElementwiseMaximumIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -457,9 +522,12 @@ class ElementwiseMaximumIntModule(torch.nn.Module):
 def ElementwiseMaximumIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5)))
 
+
 # ==============================================================================
 
+
 class ElementwiseClampModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -531,9 +599,12 @@ class ElementwiseClampMaxModule(torch.nn.Module):
 def ElementwiseClampMaxModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5, low=-10, high=10))
 
+
 # ==============================================================================
 
+
 class RsubModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -550,9 +621,12 @@ class RsubModule(torch.nn.Module):
 def RsubModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class RsubModule_noalpha(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -569,9 +643,12 @@ class RsubModule_noalpha(torch.nn.Module):
 def RsubModule_noalpha_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseMulScalarIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -588,9 +665,12 @@ class ElementwiseMulScalarIntModule(torch.nn.Module):
 def ElementwiseMulScalarModule_int(module, tu: TestUtils):
     module.forward(torch.randint(10, (3, 4)))
 
+
 # ==============================================================================
 
+
 class ElementwiseMulScalarFloatModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -607,9 +687,12 @@ class ElementwiseMulScalarFloatModule(torch.nn.Module):
 def ElementwiseMulScalarModule_float(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseMulScalarModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -626,9 +709,12 @@ class ElementwiseMulScalarModule(torch.nn.Module):
 def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseMulTensorFloatModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -646,9 +732,12 @@ class ElementwiseMulTensorFloatModule(torch.nn.Module):
 def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
 
+
 # ==============================================================================
 
+
 class ElementwiseMulTensorIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -667,9 +756,12 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
     module.forward(
         torch.randint(10, [4]).type(torch.int32), torch.randint(10, [4]))
 
+
 # ==============================================================================
 
+
 class ElementwiseLogModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -686,9 +778,12 @@ class ElementwiseLogModule(torch.nn.Module):
 def ElementwiseLogModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseLogIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -705,9 +800,12 @@ class ElementwiseLogIntModule(torch.nn.Module):
 def ElementwiseLogIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseErfModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -724,9 +822,12 @@ class ElementwiseErfModule(torch.nn.Module):
 def ElementwiseErfModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseErfIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -743,10 +844,12 @@ class ElementwiseErfIntModule(torch.nn.Module):
 def ElementwiseErfIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
 
+
 # ==============================================================================
 
 
 class ElementwiseSqrtModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -755,7 +858,6 @@ class ElementwiseSqrtModule(torch.nn.Module):
         None,
         ([-1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.sqrt(a)
 
@@ -764,9 +866,12 @@ class ElementwiseSqrtModule(torch.nn.Module):
 def ElementwiseSqrtModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseSqrtIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -775,7 +880,6 @@ class ElementwiseSqrtIntModule(torch.nn.Module):
         None,
         ([-1, -1], torch.int32, True),
     ])
-
     def forward(self, a):
         return torch.sqrt(a)
 
@@ -784,17 +888,20 @@ class ElementwiseSqrtIntModule(torch.nn.Module):
 def ElementwiseSqrtIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseFloorModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
+
     @export
     @annotate_args([
         None,
         ([-1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.floor(a)
 
@@ -803,17 +910,20 @@ class ElementwiseFloorModule(torch.nn.Module):
 def ElementwiseFloorModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseCeilModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
+
     @export
     @annotate_args([
         None,
         ([-1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.ceil(a)
 
@@ -822,17 +932,20 @@ class ElementwiseCeilModule(torch.nn.Module):
 def ElementwiseCeilModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwisePowModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
+
     @export
     @annotate_args([
         None,
         ([-1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.pow(a, 2.0)
 
@@ -841,17 +954,17 @@ class ElementwisePowModule(torch.nn.Module):
 def ElementwisePowModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
     @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.float32, True)
-    ])
+    @annotate_args([None, ([-1, -1], torch.float32, True)])
     def forward(self, x):
         return x.to(torch.int64)
 
@@ -860,17 +973,17 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
 def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
+
 # ==============================================================================
 
+
 class ElementwiseToDtypeIdentityModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
     @export
-    @annotate_args([
-        None,
-        ([-1, -1], torch.float32, True)
-    ])
+    @annotate_args([None, ([-1, -1], torch.float32, True)])
     def forward(self, x):
         return x.to(torch.float32, False, False)
 
@@ -879,9 +992,12 @@ class ElementwiseToDtypeIdentityModule(torch.nn.Module):
 def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 5))
 
+
 # ==============================================================================
 
+
 class ElementwiseLog2Module(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -898,9 +1014,12 @@ class ElementwiseLog2Module(torch.nn.Module):
 def ElementwiseLog2Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseLog2IntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -917,9 +1036,12 @@ class ElementwiseLog2IntModule(torch.nn.Module):
 def ElementwiseLog2IntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseRsqrtModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -928,7 +1050,6 @@ class ElementwiseRsqrtModule(torch.nn.Module):
         None,
         ([-1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.rsqrt(a)
 
@@ -937,9 +1058,12 @@ class ElementwiseRsqrtModule(torch.nn.Module):
 def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseRsqrtIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -948,7 +1072,6 @@ class ElementwiseRsqrtIntModule(torch.nn.Module):
         None,
         ([-1, -1], torch.int32, True),
     ])
-
     def forward(self, a):
         return torch.rsqrt(a)
 
@@ -957,17 +1080,20 @@ class ElementwiseRsqrtIntModule(torch.nn.Module):
 def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseAbsModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
+
     @export
     @annotate_args([
         None,
         ([-1, -1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.abs(a)
 
@@ -976,17 +1102,20 @@ class ElementwiseAbsModule(torch.nn.Module):
 def ElementwiseAbsModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0))
 
+
 # ==============================================================================
 
+
 class ElementwiseReciprocalModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
+
     @export
     @annotate_args([
         None,
         ([-1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.reciprocal(a)
 
@@ -995,17 +1124,20 @@ class ElementwiseReciprocalModule(torch.nn.Module):
 def ElementwiseReciprocalModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4))
 
+
 # ==============================================================================
 
+
 class ElementwiseReciprocalIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
+
     @export
     @annotate_args([
         None,
         ([-1], torch.int32, True),
     ])
-
     def forward(self, a):
         return torch.reciprocal(a)
 
@@ -1014,9 +1146,12 @@ class ElementwiseReciprocalIntModule(torch.nn.Module):
 def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (4,), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseDivScalarModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1033,9 +1168,12 @@ class ElementwiseDivScalarModule(torch.nn.Module):
 def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseDivTensorFloatModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1053,9 +1191,57 @@ class ElementwiseDivTensorFloatModule(torch.nn.Module):
 def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
 
+
 # ==============================================================================
 
+
+class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b, rounding_mode="trunc")
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseDivRoundingModeTruncModule())
+def ElementwiseDivRoundingModeTruncModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
+
+
+class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float64, True),
+    ])
+    def forward(self, a, b):
+        return torch.div(a, b, rounding_mode="floor")
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseDivRoundingModeFloorModule())
+def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
+
+
+# ==============================================================================
+
+
 class ElementwiseAndIntegerModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
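Reviewer note: the two new tests deliberately mix float32 and float64 operands, so they exercise the `convertScalarToDtype` promotion in the lowering as well as the rounding itself. Since `tu.rand` yields values in [0, 1), the quotients here are positive and `trunc`/`floor` agree; the modes only diverge on negative quotients. A quick eager-mode check of both behaviors (values verified against stock PyTorch):

```python
import torch

a = torch.tensor([1.0, -2.5, 7.0])
b = torch.tensor([2.0, 2.0, -2.0], dtype=torch.float64)

q_trunc = torch.div(a, b, rounding_mode="trunc")
q_floor = torch.div(a, b, rounding_mode="floor")
print(q_trunc)        # tensor([ 0., -1., -3.], dtype=torch.float64)
print(q_floor)        # tensor([ 0., -2., -4.], dtype=torch.float64)
print(q_trunc.dtype)  # torch.float64: float32/float64 operands promote to float64
```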
@@ -1075,9 +1261,12 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
         torch.randint(-10, 10, (3, 4)).to(torch.int32),
         torch.randint(-10, 10, (3, 4)))
 
+
 # ==============================================================================
 
+
 class ElementwiseSubScalarIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1094,9 +1283,12 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
 def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseSubScalarFloatModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1113,9 +1305,12 @@ class ElementwiseSubScalarFloatModule(torch.nn.Module):
 def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseAddScalarInt64Module(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1132,9 +1327,12 @@ class ElementwiseAddScalarInt64Module(torch.nn.Module):
 def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, (3, 4)))
 
+
 # ==============================================================================
 
+
 class ElementwiseAddScalarIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1151,9 +1349,12 @@ class ElementwiseAddScalarIntModule(torch.nn.Module):
 def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(10, (2, 3), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseAddScalarFloatModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1170,9 +1371,12 @@ class ElementwiseAddScalarFloatModule(torch.nn.Module):
 def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseCloneModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1189,9 +1393,12 @@ class ElementwiseCloneModule(torch.nn.Module):
 def ElementwiseCloneModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseCloneContiguousModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1208,9 +1415,12 @@ class ElementwiseCloneContiguousModule(torch.nn.Module):
 def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseExpModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1219,7 +1429,6 @@ class ElementwiseExpModule(torch.nn.Module):
         None,
         ([-1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.exp(a)
 
@@ -1228,9 +1437,12 @@ class ElementwiseExpModule(torch.nn.Module):
 def ElementwiseExpModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseExpIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1239,7 +1451,6 @@ class ElementwiseExpIntModule(torch.nn.Module):
         None,
         ([-1, -1], torch.int32, True),
     ])
-
     def forward(self, a):
         return torch.exp(a)
 
@@ -1251,7 +1462,9 @@ def ElementwiseExpIntModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+
 class ElementwiseSinModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1260,7 +1473,6 @@ class ElementwiseSinModule(torch.nn.Module):
         None,
         ([-1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.sin(a)
 
@@ -1269,9 +1481,12 @@ class ElementwiseSinModule(torch.nn.Module):
 def ElementwiseSinModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseSinIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1280,7 +1495,6 @@ class ElementwiseSinIntModule(torch.nn.Module):
         None,
         ([-1, -1], torch.int32, True),
     ])
-
     def forward(self, a):
         return torch.sin(a)
 
@@ -1289,9 +1503,12 @@ class ElementwiseSinIntModule(torch.nn.Module):
 def ElementwiseSinIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseCosModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1300,7 +1517,6 @@ class ElementwiseCosModule(torch.nn.Module):
         None,
         ([-1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.cos(a)
 
@@ -1309,9 +1525,12 @@ class ElementwiseCosModule(torch.nn.Module):
 def ElementwiseCosModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(3, 4))
 
+
 # ==============================================================================
 
+
 class ElementwiseCosIntModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1320,7 +1539,6 @@ class ElementwiseCosIntModule(torch.nn.Module):
         None,
         ([-1, -1], torch.int32, True),
     ])
-
     def forward(self, a):
         return torch.cos(a)
 
@@ -1329,9 +1547,12 @@ class ElementwiseCosIntModule(torch.nn.Module):
 def ElementwiseCosIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
 
+
 # ==============================================================================
 
+
 class ElementwiseNegModule(torch.nn.Module):
+
     def __init__(self):
         super().__init__()
 
@@ -1340,7 +1561,6 @@ class ElementwiseNegModule(torch.nn.Module):
         None,
         ([-1, -1], torch.float32, True),
     ])
-
     def forward(self, a):
         return torch.neg(a)