mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.div.Tensor_mode op
This commit adds lowering of `aten.div.Tensor_mode` op. This commit also fixes formatting for the test file elementwise.py. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/908/head
parent
a11ef674a7
commit
b95b3d844d
|
@ -1029,6 +1029,55 @@ def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other,
|
||||
AnyTorchOptionalStringType:$rounding_mode
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenDivTensorModeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenDivTensorModeOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::div_.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$other,
|
||||
AnyTorchOptionalStringType:$rounding_mode
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenDiv_TensorModeOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenDiv_TensorModeOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -447,12 +447,54 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Type dtype = converter->convertType(div.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType>())
|
||||
if (!dtype.isa<mlir::FloatType>()) {
|
||||
div.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
return b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
}
|
||||
if (auto divTensorMode = dyn_cast<AtenDivTensorModeOp>(op)) {
|
||||
AtenDivTensorModeOp::Adaptor adaptor(operands);
|
||||
Type dtype = converter->convertType(divTensorMode.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
if (!dtype.isa<mlir::FloatType>()) {
|
||||
divTensorMode.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||
Value div = b.create<arith::DivFOp>(loc, lhs, rhs);
|
||||
|
||||
if (divTensorMode.rounding_mode().getType().isa<Torch::NoneType>())
|
||||
return div;
|
||||
|
||||
std::string roundingMode;
|
||||
if (!matchPattern(divTensorMode.rounding_mode(),
|
||||
m_TorchConstantStr(roundingMode))) {
|
||||
divTensorMode.emitError("only support constant str rounding mode");
|
||||
return nullptr;
|
||||
}
|
||||
if (roundingMode == "trunc") {
|
||||
// "trunc" - rounds the results of the division towards zero. Equivalent
|
||||
// to C-style integer division.
|
||||
Value ceil = b.create<math::CeilOp>(loc, div);
|
||||
Value floor = b.create<math::FloorOp>(loc, div);
|
||||
Value cstZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(dtype));
|
||||
Value pred =
|
||||
b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, div, cstZero);
|
||||
return b.create<arith::SelectOp>(loc, pred, ceil, floor);
|
||||
}
|
||||
if (roundingMode == "floor") {
|
||||
// "floor" - rounds the results of the division down. Equivalent to
|
||||
// floor division in Python (the // operator)
|
||||
return b.create<math::FloorOp>(loc, div);
|
||||
}
|
||||
divTensorMode.emitError("invalid rounding mode");
|
||||
return nullptr;
|
||||
}
|
||||
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
|
||||
if (!pow.getType()
|
||||
.cast<ValueTensorType>()
|
||||
|
@ -845,17 +887,17 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
|
||||
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
|
||||
AtenDivTensorOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp,
|
||||
AtenExpOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp,
|
||||
AtenClampOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp,
|
||||
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
|
||||
AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp,
|
||||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
|
||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
||||
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
|
||||
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
|
||||
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp,
|
||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||
AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
|
||||
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
||||
AtenLogicalOrOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
|
@ -1585,15 +1627,15 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
MLIRContext *context = patterns.getContext();
|
||||
target.addIllegalOp<
|
||||
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
|
||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenSubTensorOp,
|
||||
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenMaximumOp,
|
||||
AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp,
|
||||
AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp, AtenLog2Op,
|
||||
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
||||
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
||||
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
|
||||
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
|
||||
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
|
||||
AtenLog2Op, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
||||
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
|
||||
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
||||
AtenLogicalOrOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
|
|
|
@ -701,8 +701,8 @@ ChangeResult TypeAnalyzer::visitOperation(
|
|||
|
||||
// Promote the two dtypes assuming possibly-zero rank.
|
||||
if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
|
||||
Aten__And__TensorOp, AtenMinimumOp, AtenMaximumOp,
|
||||
AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) {
|
||||
AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
|
||||
AtenMaximumOp, AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) {
|
||||
auto knowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
knowledge.dtype = getPromotedResultType(
|
||||
|
|
|
@ -2196,6 +2196,10 @@ module {
|
|||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.div.Tensor_mode"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<str>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
|
|
|
@ -711,6 +711,9 @@ def aten〇mul〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
|||
def aten〇div〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.broadcast(self, other)
|
||||
|
||||
def aten〇div〇Tensor_mode(self: List[int], other: List[int], rounding_mode: Optional[str]) -> List[int]:
|
||||
return upstream_shape_helpers.broadcast(self, other)
|
||||
|
||||
def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_helpers.broadcast(self, other)
|
||||
|
||||
|
|
|
@ -250,6 +250,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::logical_or : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)",
|
||||
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
|
||||
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||
|
|
|
@ -18,7 +18,9 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseUnaryModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -35,9 +37,12 @@ class ElementwiseUnaryModule(torch.nn.Module):
|
|||
def ElementwiseUnaryModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseUnaryIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -54,9 +59,12 @@ class ElementwiseUnaryIntModule(torch.nn.Module):
|
|||
def ElementwiseUnaryIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseBinaryModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -74,9 +82,12 @@ class ElementwiseBinaryModule(torch.nn.Module):
|
|||
def ElementwiseBinaryModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseBinaryStaticShapeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -95,9 +106,12 @@ class ElementwiseBinaryStaticShapeModule(torch.nn.Module):
|
|||
def ElementwiseBinaryStaticShapeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 4, 3, 3, 1), tu.rand(4, 3, 1, 2))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseTernaryModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -116,9 +130,12 @@ class ElementwiseTernaryModule(torch.nn.Module):
|
|||
def ElementwiseTernaryModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseWhereSelfModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -137,9 +154,12 @@ class ElementwiseWhereSelfModule(torch.nn.Module):
|
|||
def ElementwiseWhereSelfModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5), tu.rand(5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseWhereScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -156,9 +176,12 @@ class ElementwiseWhereScalarModule(torch.nn.Module):
|
|||
def ElementwiseWhereScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseWhereScalarOtherModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -176,9 +199,12 @@ class ElementwiseWhereScalarOtherModule(torch.nn.Module):
|
|||
def ElementwiseWhereScalarOtherModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseWhereScalarSelfModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -196,11 +222,14 @@ class ElementwiseWhereScalarSelfModule(torch.nn.Module):
|
|||
def ElementwiseWhereScalarSelfModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).double(), tu.rand(4, 5).double())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
# Addition is an interesting special case of a binary op, because under the hood
|
||||
# it carries a third scalar "alpha" parameter, which needs special handling.
|
||||
class ElementwiseAddModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -218,9 +247,12 @@ class ElementwiseAddModule(torch.nn.Module):
|
|||
def ElementwiseAddModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -239,9 +271,12 @@ class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
|
|||
def ElementwiseUnsqueezeBroadcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -261,9 +296,12 @@ class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
|
|||
def ElementwiseUnsqueezeNegDimsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseFlattenBroadcastModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -281,9 +319,12 @@ class ElementwiseFlattenBroadcastModule(torch.nn.Module):
|
|||
def ElementwiseFlattenBroadcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6), tu.rand())
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseReluModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -300,9 +341,12 @@ class ElementwiseReluModule(torch.nn.Module):
|
|||
def ElementwiseReluModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 2) - 0.5)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLeakyReluModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -319,9 +363,12 @@ class ElementwiseLeakyReluModule(torch.nn.Module):
|
|||
def ElementwiseLeakyReluModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 2) - 0.5)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseGeluModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gelu = torch.nn.GELU()
|
||||
|
@ -339,9 +386,12 @@ class ElementwiseGeluModule(torch.nn.Module):
|
|||
def ElementwiseGeluModule_basic(module, tu: TestUtils):
|
||||
module.forward(2 * tu.rand(5, 3) - 0.5)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSigmoidModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -358,9 +408,12 @@ class ElementwiseSigmoidModule(torch.nn.Module):
|
|||
def ElementwiseSigmoidModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSigmoidIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -377,9 +430,12 @@ class ElementwiseSigmoidIntModule(torch.nn.Module):
|
|||
def ElementwiseSigmoidIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 5), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMinimumModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -397,9 +453,12 @@ class ElementwiseMinimumModule(torch.nn.Module):
|
|||
def ElementwiseMinimumModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMinimumIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -417,9 +476,12 @@ class ElementwiseMinimumIntModule(torch.nn.Module):
|
|||
def ElementwiseMinimumIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5)))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMaximumModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -437,9 +499,12 @@ class ElementwiseMaximumModule(torch.nn.Module):
|
|||
def ElementwiseMaximumModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMaximumIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -457,9 +522,12 @@ class ElementwiseMaximumIntModule(torch.nn.Module):
|
|||
def ElementwiseMaximumIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5)))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseClampModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -531,9 +599,12 @@ class ElementwiseClampMaxModule(torch.nn.Module):
|
|||
def ElementwiseClampMaxModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5, low=-10, high=10))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RsubModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -550,9 +621,12 @@ class RsubModule(torch.nn.Module):
|
|||
def RsubModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class RsubModule_noalpha(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -569,9 +643,12 @@ class RsubModule_noalpha(torch.nn.Module):
|
|||
def RsubModule_noalpha_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMulScalarIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -588,9 +665,12 @@ class ElementwiseMulScalarIntModule(torch.nn.Module):
|
|||
def ElementwiseMulScalarModule_int(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMulScalarFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -607,9 +687,12 @@ class ElementwiseMulScalarFloatModule(torch.nn.Module):
|
|||
def ElementwiseMulScalarModule_float(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMulScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -626,9 +709,12 @@ class ElementwiseMulScalarModule(torch.nn.Module):
|
|||
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMulTensorFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -646,9 +732,12 @@ class ElementwiseMulTensorFloatModule(torch.nn.Module):
|
|||
def ElementwiseMulTensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseMulTensorIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -667,9 +756,12 @@ def ElementwiseMulTensorIntModule_basic(module, tu: TestUtils):
|
|||
module.forward(
|
||||
torch.randint(10, [4]).type(torch.int32), torch.randint(10, [4]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLogModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -686,9 +778,12 @@ class ElementwiseLogModule(torch.nn.Module):
|
|||
def ElementwiseLogModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLogIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -705,9 +800,12 @@ class ElementwiseLogIntModule(torch.nn.Module):
|
|||
def ElementwiseLogIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseErfModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -724,9 +822,12 @@ class ElementwiseErfModule(torch.nn.Module):
|
|||
def ElementwiseErfModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseErfIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -743,10 +844,12 @@ class ElementwiseErfIntModule(torch.nn.Module):
|
|||
def ElementwiseErfIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSqrtModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -755,7 +858,6 @@ class ElementwiseSqrtModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.sqrt(a)
|
||||
|
||||
|
@ -764,9 +866,12 @@ class ElementwiseSqrtModule(torch.nn.Module):
|
|||
def ElementwiseSqrtModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSqrtIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -775,7 +880,6 @@ class ElementwiseSqrtIntModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.sqrt(a)
|
||||
|
||||
|
@ -784,17 +888,20 @@ class ElementwiseSqrtIntModule(torch.nn.Module):
|
|||
def ElementwiseSqrtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseFloorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.floor(a)
|
||||
|
||||
|
@ -803,17 +910,20 @@ class ElementwiseFloorModule(torch.nn.Module):
|
|||
def ElementwiseFloorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCeilModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.ceil(a)
|
||||
|
||||
|
@ -822,17 +932,20 @@ class ElementwiseCeilModule(torch.nn.Module):
|
|||
def ElementwiseCeilModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwisePowModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.pow(a, 2.0)
|
||||
|
||||
|
@ -841,17 +954,17 @@ class ElementwisePowModule(torch.nn.Module):
|
|||
def ElementwisePowModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True)
|
||||
])
|
||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int64)
|
||||
|
||||
|
@ -860,17 +973,17 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
|
|||
def ElementwiseToDtypeF32ToI64Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseToDtypeIdentityModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True)
|
||||
])
|
||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float32, False, False)
|
||||
|
||||
|
@ -879,9 +992,12 @@ class ElementwiseToDtypeIdentityModule(torch.nn.Module):
|
|||
def ElementwiseToDtypeIdentityModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLog2Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -898,9 +1014,12 @@ class ElementwiseLog2Module(torch.nn.Module):
|
|||
def ElementwiseLog2Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseLog2IntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -917,9 +1036,12 @@ class ElementwiseLog2IntModule(torch.nn.Module):
|
|||
def ElementwiseLog2IntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRsqrtModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -928,7 +1050,6 @@ class ElementwiseRsqrtModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.rsqrt(a)
|
||||
|
||||
|
@ -937,9 +1058,12 @@ class ElementwiseRsqrtModule(torch.nn.Module):
|
|||
def ElementwiseRsqrtModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRsqrtIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -948,7 +1072,6 @@ class ElementwiseRsqrtIntModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.rsqrt(a)
|
||||
|
||||
|
@ -957,17 +1080,20 @@ class ElementwiseRsqrtIntModule(torch.nn.Module):
|
|||
def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAbsModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.abs(a)
|
||||
|
||||
|
@ -976,17 +1102,20 @@ class ElementwiseAbsModule(torch.nn.Module):
|
|||
def ElementwiseAbsModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseReciprocalModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.reciprocal(a)
|
||||
|
||||
|
@ -995,17 +1124,20 @@ class ElementwiseReciprocalModule(torch.nn.Module):
|
|||
def ElementwiseReciprocalModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseReciprocalIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.reciprocal(a)
|
||||
|
||||
|
@ -1014,9 +1146,12 @@ class ElementwiseReciprocalIntModule(torch.nn.Module):
|
|||
def ElementwiseReciprocalIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (4,), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseDivScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1033,9 +1168,12 @@ class ElementwiseDivScalarModule(torch.nn.Module):
|
|||
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseDivTensorFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1053,9 +1191,57 @@ class ElementwiseDivTensorFloatModule(torch.nn.Module):
|
|||
def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b, rounding_mode="trunc")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeTruncModule())
|
||||
def ElementwiseDivRoundingModeTruncModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4), tu.rand(4).type(torch.float64))
|
||||
|
||||
|
||||
class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b, rounding_mode="floor")
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: ElementwiseDivRoundingModeFloorModule())
|
||||
def ElementwiseDivRoundingModeFloorModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4), tu.rand(3, 4).type(torch.float64))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAndIntegerModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1075,9 +1261,12 @@ def ElementwiseAndIntegerModule_basic(module, tu: TestUtils):
|
|||
torch.randint(-10, 10, (3, 4)).to(torch.int32),
|
||||
torch.randint(-10, 10, (3, 4)))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSubScalarIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1094,9 +1283,12 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
|
|||
def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSubScalarFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1113,9 +1305,12 @@ class ElementwiseSubScalarFloatModule(torch.nn.Module):
|
|||
def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAddScalarInt64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1132,9 +1327,12 @@ class ElementwiseAddScalarInt64Module(torch.nn.Module):
|
|||
def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAddScalarIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1151,9 +1349,12 @@ class ElementwiseAddScalarIntModule(torch.nn.Module):
|
|||
def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (2, 3), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAddScalarFloatModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1170,9 +1371,12 @@ class ElementwiseAddScalarFloatModule(torch.nn.Module):
|
|||
def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCloneModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1189,9 +1393,12 @@ class ElementwiseCloneModule(torch.nn.Module):
|
|||
def ElementwiseCloneModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCloneContiguousModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1208,9 +1415,12 @@ class ElementwiseCloneContiguousModule(torch.nn.Module):
|
|||
def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseExpModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1219,7 +1429,6 @@ class ElementwiseExpModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.exp(a)
|
||||
|
||||
|
@ -1228,9 +1437,12 @@ class ElementwiseExpModule(torch.nn.Module):
|
|||
def ElementwiseExpModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseExpIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1239,7 +1451,6 @@ class ElementwiseExpIntModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.exp(a)
|
||||
|
||||
|
@ -1251,7 +1462,9 @@ def ElementwiseExpIntModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSinModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1260,7 +1473,6 @@ class ElementwiseSinModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.sin(a)
|
||||
|
||||
|
@ -1269,9 +1481,12 @@ class ElementwiseSinModule(torch.nn.Module):
|
|||
def ElementwiseSinModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseSinIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1280,7 +1495,6 @@ class ElementwiseSinIntModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.sin(a)
|
||||
|
||||
|
@ -1289,9 +1503,12 @@ class ElementwiseSinIntModule(torch.nn.Module):
|
|||
def ElementwiseSinIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCosModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1300,7 +1517,6 @@ class ElementwiseCosModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.cos(a)
|
||||
|
||||
|
@ -1309,9 +1525,12 @@ class ElementwiseCosModule(torch.nn.Module):
|
|||
def ElementwiseCosModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseCosIntModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1320,7 +1539,6 @@ class ElementwiseCosIntModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.cos(a)
|
||||
|
||||
|
@ -1329,9 +1547,12 @@ class ElementwiseCosIntModule(torch.nn.Module):
|
|||
def ElementwiseCosIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseNegModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1340,7 +1561,6 @@ class ElementwiseNegModule(torch.nn.Module):
|
|||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
return torch.neg(a)
|
||||
|
||||
|
|
Loading…
Reference in New Issue