mirror of https://github.com/llvm/torch-mlir
parent
9cac480a18
commit
d9cbf01d1e
|
@ -12,12 +12,7 @@
|
|||
|
||||
from torch_mlir_e2e_test.test_suite import COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
|
||||
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
||||
"UnsafeView1DFoldModule_basic",
|
||||
"View1DFoldModule_basic",
|
||||
"ViewCollapseInferredDimModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
}
|
||||
REFBACKEND_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS
|
||||
|
||||
EAGER_MODE_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
||||
# RefBackend fails for some reason.
|
||||
|
@ -25,10 +20,6 @@ EAGER_MODE_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
|||
# why they fail here.
|
||||
"Matmul_vecmat",
|
||||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"UnsafeView1DFoldModule_basic",
|
||||
"View1DFoldModule_basic",
|
||||
"ViewCollapseInferredDimModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
}
|
||||
|
||||
TORCHDYNAMO_XFAIL_SET = {
|
||||
|
@ -299,6 +290,8 @@ MHLO_PASS_SET = {
|
|||
"TModuleRank1_basic",
|
||||
"TModuleRank0_basic",
|
||||
"ElementwiseToDtypeIdentityModule_basic",
|
||||
"View1DFoldModule_basic",
|
||||
"UnsafeView1DFoldModule_basic",
|
||||
"RsubFloatModule_basic",
|
||||
"RsubFloatModule_noalpha_basic",
|
||||
"RsubIntModule_basic",
|
||||
|
@ -363,6 +356,7 @@ MHLO_PASS_SET = {
|
|||
"ContiguousModule_basic",
|
||||
"DropoutModule_basic",
|
||||
"ViewCollapseModule_basic",
|
||||
"ViewCollapseInferredDimModule_basic",
|
||||
"ViewDynamicExpandCollapseModule_basic",
|
||||
"ViewDynamicExpandModule_basic",
|
||||
"ViewExpandModule_basic",
|
||||
|
@ -371,6 +365,7 @@ MHLO_PASS_SET = {
|
|||
"ViewExpandOnesMiddleModule_basic",
|
||||
"ViewExpandCollapseModule_basic",
|
||||
"ViewExpandCollapseWithOnesModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
"ViewNoChangeStaticModule_basic",
|
||||
"ViewNoChange1dModule_basic",
|
||||
"ViewNoChange2dModule_basic",
|
||||
|
@ -432,7 +427,6 @@ MHLO_PASS_SET = {
|
|||
"UnsafeViewCollapseModule_basic",
|
||||
"UnsafeViewDynamicExpandModule_basic",
|
||||
"AtenRoundIntModule_basic",
|
||||
"PrimsConvertElementTypeModule_basic",
|
||||
}
|
||||
|
||||
# Write the TOSA set as a "passing" set as it is very early in development
|
||||
|
@ -478,6 +472,8 @@ TOSA_PASS_SET = {
|
|||
"TModuleRank0_basic",
|
||||
"ElementwiseToDtypeIdentityModule_basic",
|
||||
"AtenToDeviceModule_basic",
|
||||
"View1DFoldModule_basic",
|
||||
"UnsafeView1DFoldModule_basic",
|
||||
"SqueezeDimModule_static",
|
||||
"SqueezeDimModule_identity",
|
||||
"SqueezeDimModule_unitDim",
|
||||
|
@ -574,6 +570,8 @@ TOSA_PASS_SET = {
|
|||
"ViewExpandOnesMiddleModule_basic",
|
||||
"ViewExpandCollapseModule_basic",
|
||||
"ViewExpandCollapseWithOnesModule_basic",
|
||||
"ViewCollapseInferredDimModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
"ViewNoChangeStaticModule_basic",
|
||||
"UnsafeViewExpandModule_basic",
|
||||
"ReshapeCollapseModule_basic",
|
||||
|
@ -772,8 +770,4 @@ LTC_XFAIL_SET = {
|
|||
"VarMeanCorrectionModule_basic",
|
||||
"VarMeanCorrectionNoneModule_basic",
|
||||
"PrimsConvertElementTypeModule_basic",
|
||||
"UnsafeView1DFoldModule_basic",
|
||||
"View1DFoldModule_basic",
|
||||
"ViewCollapseInferredDimModule_basic",
|
||||
"ViewExpandInferredDimModule_basic",
|
||||
}
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 147fe9de29dc13c14835127b35280c4d95c8e8ba
|
||||
Subproject commit e864ac694540342d5e59f59c525c5082f2594fb8
|
|
@ -1 +1 @@
|
|||
Subproject commit 1944b5fa6062ec4c065d726c9c5d64f1487ee8c5
|
||||
Subproject commit eab364ba2a66bd0613efb94f8a738c1c97aaee92
|
|
@ -35,7 +35,7 @@ using GetTensorTypeFn =
|
|||
llvm::function_ref<Type(MLIRContext *, Optional<ArrayRef<int64_t>>, Type)>;
|
||||
|
||||
/// The representation of an unknown dimension size in an ArrayRef<int64_t>.
|
||||
constexpr static int64_t kUnknownSize = ShapedType::kDynamicSize;
|
||||
constexpr static int64_t kUnknownSize = -1;
|
||||
|
||||
class BaseTensorType : public Type {
|
||||
public:
|
||||
|
|
|
@ -299,14 +299,14 @@ public:
|
|||
int64_t outputDynamicValues = 0;
|
||||
|
||||
for (int64_t value : inputShape) {
|
||||
if (value == kUnknownSize) {
|
||||
if (value == -1) {
|
||||
++inputDynamicValues;
|
||||
} else {
|
||||
inputProduct *= value;
|
||||
}
|
||||
}
|
||||
for (int64_t value : outputShape) {
|
||||
if (value == kUnknownSize) {
|
||||
if (value == -1) {
|
||||
++outputDynamicValues;
|
||||
} else {
|
||||
outputProduct *= value;
|
||||
|
@ -317,7 +317,7 @@ public:
|
|||
if (inputDynamicValues) {
|
||||
int64_t missingValue = outputProduct / inputProduct;
|
||||
for (size_t i = 0; i < inputShape.size(); i++) {
|
||||
if (inputShape[i] == kUnknownSize) {
|
||||
if (inputShape[i] == -1) {
|
||||
inputShape[i] = missingValue;
|
||||
break;
|
||||
}
|
||||
|
@ -325,7 +325,7 @@ public:
|
|||
} else {
|
||||
int64_t missingValue = inputProduct / outputProduct;
|
||||
for (size_t i = 0; i < outputShape.size(); i++) {
|
||||
if (outputShape[i] == kUnknownSize) {
|
||||
if (outputShape[i] == -1) {
|
||||
outputShape[i] = missingValue;
|
||||
break;
|
||||
}
|
||||
|
@ -415,7 +415,7 @@ public:
|
|||
|
||||
if (inferredDimension.has_value()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "at most one element in size list is allowed to be kUnknownSize");
|
||||
op, "at most one element in size list is allowed to be -1");
|
||||
}
|
||||
inferredDimension = outputDim;
|
||||
}
|
||||
|
@ -652,7 +652,7 @@ public:
|
|||
|
||||
SmallVector<int64_t> intermediateShape;
|
||||
for (auto i : llvm::seq(0, (int)outputAssociations.size())) {
|
||||
int64_t sum = 1;
|
||||
int sum = 1;
|
||||
|
||||
for (auto j : llvm::seq(0, (int)outputAssociations[i].size())) {
|
||||
if (outputShape[outputAssociations[i][j]] < 0) {
|
||||
|
|
|
@ -608,7 +608,7 @@ public:
|
|||
auto resultDimSize = refinedResultShape[i];
|
||||
if (ShapedType::isDynamic(resultDimSize)) {
|
||||
SmallVector<Value> dynamicDims;
|
||||
int64_t staticDimSize = kUnknownSize;
|
||||
int64_t staticDimSize = -1;
|
||||
for (auto indexTensor : indexTensors) {
|
||||
RankedTensorType indexTensorType =
|
||||
indexTensor.getType().cast<RankedTensorType>();
|
||||
|
|
|
@ -1900,11 +1900,8 @@ LogicalResult ConvertAtenOp<AtenReshapeOp>::matchAndRewrite(
|
|||
op, "Only constant shape supported in TOSA Reshape");
|
||||
|
||||
int auto_sz = 0;
|
||||
for (unsigned i = 0; i < newShape.size(); i++)
|
||||
if (newShape[i] == -1) {
|
||||
auto_sz += 1;
|
||||
newShape[i] = kUnknownSize;
|
||||
}
|
||||
for (auto s : newShape)
|
||||
auto_sz += (s == -1 ? 1 : 0);
|
||||
if (auto_sz > 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "At most one dimension may be specified as -1 to "
|
||||
|
|
|
@ -268,8 +268,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
|
|||
int64_t axis_val = axes_elems.getValues<IntegerAttr>()[i].getInt();
|
||||
if (axis_val < 0)
|
||||
axis_val += input_rank;
|
||||
int64_t inShape = input_type.getShape()[axis_val];
|
||||
num_elems_on_reduced_axis *= inShape < 0 ? -1 : inShape;
|
||||
num_elems_on_reduced_axis *= input_type.getShape()[axis_val];
|
||||
}
|
||||
double div_scale = 1.0 / static_cast<double>(num_elems_on_reduced_axis);
|
||||
|
||||
|
|
|
@ -320,7 +320,7 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
|
|||
// return -1.
|
||||
int64_t getNumberOfElements(RankedTensorType inputType) {
|
||||
if (!inputType.hasStaticShape())
|
||||
return kUnknownSize;
|
||||
return -1;
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
int64_t numel = 1;
|
||||
for (int64_t i = 0; i < inputType.getRank(); i++)
|
||||
|
|
|
@ -231,7 +231,7 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser,
|
|||
}
|
||||
}
|
||||
if (succeeded(parser.parseOptionalQuestion())) {
|
||||
sizes.push_back(kUnknownSize);
|
||||
sizes.push_back(-1);
|
||||
continue;
|
||||
}
|
||||
int64_t size;
|
||||
|
|
|
@ -974,8 +974,8 @@ public:
|
|||
newSizes.push_back(
|
||||
rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dim));
|
||||
}
|
||||
Value flattenDimSize = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(kUnknownSize));
|
||||
Value flattenDimSize =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(-1));
|
||||
newSizes.push_back(flattenDimSize);
|
||||
for (int64_t k = end + 1; k < rank; ++k) {
|
||||
Value dim =
|
||||
|
|
|
@ -239,15 +239,14 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
|
|||
continue;
|
||||
// If the op is read-only and all results are safe, then this value is
|
||||
// safe. This covers, for example, view-like ops that create aliases.
|
||||
if (auto effects = dyn_cast<MemoryEffectOpInterface>(op)) {
|
||||
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() || effects.hasNoEffect()) &&
|
||||
llvm::all_of(op->getResults(), [&](Value result) {
|
||||
auto *state =
|
||||
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result);
|
||||
return state->isSafe;
|
||||
}))
|
||||
continue;
|
||||
}
|
||||
if ((op->hasTrait<Torch::OpTrait::ReadOnly>() ||
|
||||
MemoryEffectOpInterface::hasNoEffect(op)) &&
|
||||
llvm::all_of(op->getResults(), [&](Value result) {
|
||||
auto *state =
|
||||
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value, result);
|
||||
return state->isSafe;
|
||||
}))
|
||||
continue;
|
||||
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
|
||||
auto symName = initialize.slotSymNames()[use.getOperandNumber()]
|
||||
.cast<FlatSymbolRefAttr>();
|
||||
|
|
|
@ -15,9 +15,6 @@ from torch_mlir.passmanager import PassManager
|
|||
from .compiler_utils import run_pipeline_with_repro_report
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import ClassAnnotator, ImportOptions, ModuleBuilder
|
||||
|
||||
# This value is taken from `ShapedType::kDynamicSize` variable
|
||||
# defined here: "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
MLIR_DYNAMIC_STRIDE = -9223372036854775808
|
||||
|
||||
class OutputType(Enum):
|
||||
"""The kind of output that `torch_mlir.compile` can produce.
|
||||
|
@ -115,7 +112,7 @@ class TensorPlaceholder:
|
|||
shape = []
|
||||
for i, dim in enumerate(tensor.shape):
|
||||
if i in dynamic_axes:
|
||||
shape.append(MLIR_DYNAMIC_STRIDE)
|
||||
shape.append(-1)
|
||||
else:
|
||||
shape.append(dim)
|
||||
return TensorPlaceholder(shape, tensor.dtype)
|
||||
|
@ -224,7 +221,7 @@ class ExampleArgs:
|
|||
# tracing, they are walking on thin ice already -- assume
|
||||
# they know what they are doing and that their trace is
|
||||
# correct for any specific concrete size.
|
||||
shape = [s if s != MLIR_DYNAMIC_STRIDE else 7 for s in arg.shape]
|
||||
shape = [s if s != -1 else 7 for s in arg.shape]
|
||||
example_args_for_trace.append(
|
||||
torch.ones(*shape, dtype=arg.dtype))
|
||||
else:
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
#include "mlir_lowering_context.h"
|
||||
#include "mlir_node.h"
|
||||
#include "ops/device_data.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
#include <ATen/Functions.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
|
@ -144,8 +143,7 @@ get_tensor_type_shape(c10::TensorType& tensor_type) {
|
|||
dims.resize(*symbolic_shape.rank());
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
auto shape_symbol = symbolic_shape[i];
|
||||
dims[i] = shape_symbol.is_static() ? shape_symbol.static_size()
|
||||
: mlir::torch::Torch::kUnknownSize;
|
||||
dims[i] = shape_symbol.is_static() ? shape_symbol.static_size() : -1;
|
||||
}
|
||||
|
||||
return dims;
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "mlir-c/Diagnostics.h"
|
||||
#include "torch-mlir-c/TorchOps.h"
|
||||
#include "torch-mlir-c/TorchTypes.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
|
@ -163,8 +162,7 @@ torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
|
|||
dims.resize(*sizes.rank());
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
auto shapeSymbol = symbolicShape[i];
|
||||
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size()
|
||||
: mlir::torch::Torch::kUnknownSize;
|
||||
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
||||
}
|
||||
|
||||
// `std::vector`'s `.data()` method can return nullptr when the
|
||||
|
|
|
@ -17,7 +17,7 @@ class ArgmaxModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, a):
|
||||
|
@ -37,7 +37,7 @@ class ArgmaxWithDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.argmax(a, dim=1)
|
||||
|
@ -55,7 +55,7 @@ class ArgmaxKeepDimsModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.argmax(a, 0, True)
|
||||
|
|
|
@ -20,8 +20,8 @@ class SoftmaxBackwardModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, output):
|
||||
return torch.ops.aten._softmax_backward_data(grad_output,
|
||||
|
@ -44,8 +44,8 @@ class TanhBackwardModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_out, output):
|
||||
return torch.ops.aten.tanh_backward(grad_out, output)
|
||||
|
@ -66,9 +66,9 @@ class ConvolutionBackwardModule2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_out, input_vec, weight):
|
||||
return torch.ops.aten.convolution_backward(
|
||||
|
@ -100,9 +100,9 @@ class ConvolutionBackwardModule2DPadded(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_out, input_vec, weight):
|
||||
return torch.ops.aten.convolution_backward(
|
||||
|
@ -137,8 +137,8 @@ class GeluBackwardModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad, input):
|
||||
return torch.ops.aten.gelu_backward(grad, input)
|
||||
|
@ -157,8 +157,8 @@ class LogSoftmaxBackwardModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, output):
|
||||
return torch.ops.aten._log_softmax_backward_data(grad_output,
|
||||
|
|
|
@ -20,8 +20,8 @@ class MmModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.mm(lhs, rhs)
|
||||
|
@ -49,8 +49,8 @@ class BmmModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.bmm(lhs, rhs)
|
||||
|
@ -72,7 +72,7 @@ class IsFloatingPointInt(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.is_floating_point(x)
|
||||
|
@ -94,7 +94,7 @@ class IsFloatingPointFloat(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.is_floating_point(x)
|
||||
|
@ -174,8 +174,8 @@ class MmTanhModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.tanh(self.matmul(lhs, rhs))
|
||||
|
@ -200,9 +200,9 @@ class AddmmModuleFloat(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, M, mat1, mat2):
|
||||
return torch.addmm(M, mat1, mat2, beta=3.0, alpha=7.0)
|
||||
|
@ -224,9 +224,9 @@ class AddmmModuleBroadcastable(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, M, mat1, mat2):
|
||||
return torch.addmm(M, mat1, mat2, beta=2.0, alpha=7.0)
|
||||
|
@ -248,9 +248,9 @@ class AddmmModuleDifferentRankBroadcastable(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, M, mat1, mat2):
|
||||
return torch.addmm(M, mat1, mat2, beta=11.0, alpha=7.0)
|
||||
|
@ -320,7 +320,7 @@ class FlattenDynamicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, 9, 3, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, 9, 3, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.flat(x)
|
||||
|
@ -365,7 +365,7 @@ class PadModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
pad = [0, 1, 2, 3]
|
||||
|
@ -389,7 +389,7 @@ class PadWithNoneValModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
pad = [0, 1, 2, 3]
|
||||
|
@ -413,7 +413,7 @@ class ConstantPadNdModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.constant_pad_nd(x, (0, 1), -float('inf'))
|
||||
|
@ -457,7 +457,7 @@ class ConstantPadNdPartialStaticModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([1, 1, 20, 20, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([1, 1, 20, 20, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.constant_pad_nd(x, (0, 1, 2, 3), -float('inf'))
|
||||
|
@ -580,9 +580,9 @@ class TensorsConcatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
return torch.cat([x, y, z], 1)
|
||||
|
@ -604,9 +604,9 @@ class TensorsConcatNegativeDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y, z):
|
||||
return torch.cat([x, y, z], dim=-2)
|
||||
|
@ -628,8 +628,8 @@ class GatherModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor, indices):
|
||||
return torch.gather(tensor, 2, indices)
|
||||
|
@ -650,8 +650,8 @@ class GatherRandomIndexModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor, indices):
|
||||
return torch.gather(tensor, 1, indices)
|
||||
|
@ -671,8 +671,8 @@ class Gather2DInputModdule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor, indices):
|
||||
return torch.gather(tensor, 1, indices)
|
||||
|
@ -716,7 +716,7 @@ class AddSizeIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
# This is a workaround for not supporting scalar arguments.
|
||||
|
@ -741,7 +741,7 @@ class AddSizeIntNegDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
# This is a workaround for not supporting scalar arguments.
|
||||
|
@ -770,7 +770,7 @@ class EmbeddingModuleI64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, indices):
|
||||
return self.embed.forward(indices)
|
||||
|
@ -796,7 +796,7 @@ class EmbeddingModuleI32(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, indices):
|
||||
return self.embed.forward(indices)
|
||||
|
@ -847,7 +847,7 @@ class EmbeddingModule1DIndices(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int32, True),
|
||||
([-1], torch.int32, True),
|
||||
])
|
||||
def forward(self, indices):
|
||||
return self.embed.forward(indices)
|
||||
|
@ -870,7 +870,7 @@ class SoftmaxIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return self.softmax.forward(tensor)
|
||||
|
@ -892,7 +892,7 @@ class _SoftmaxModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten._softmax(tensor, 0, False)
|
||||
|
@ -916,7 +916,7 @@ class SoftmaxIntNegDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return self.softmax.forward(tensor)
|
||||
|
@ -940,7 +940,7 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return self.softmax.forward(tensor)
|
||||
|
@ -962,7 +962,7 @@ class _LogSoftmaxModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten._log_softmax(tensor, dim=0, half_to_float=False)
|
||||
|
@ -984,7 +984,7 @@ class _LogSoftmaxModuleStable(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten._log_softmax(tensor, dim=0, half_to_float=False)
|
||||
|
@ -1009,7 +1009,7 @@ class SoftplusModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.softplus(x)
|
||||
|
@ -1031,7 +1031,7 @@ class HardsigmoidModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardsigmoid(x)
|
||||
|
@ -1053,7 +1053,7 @@ class HardsigmoidRandomModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardsigmoid(x)
|
||||
|
@ -1075,7 +1075,7 @@ class BroadcastToModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, 1], torch.float32, True),
|
||||
([-1, -1, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.broadcast_to(x, [1, -1, -1, 4])
|
||||
|
@ -1144,7 +1144,7 @@ class RollModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, -9223372036854775808, 2], torch.float32, True),
|
||||
([3, -1, 2], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.roll([2, -1], [0, 2])
|
||||
|
@ -1186,7 +1186,7 @@ class ExpandModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, 1], torch.float32, True),
|
||||
([-1, -1, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.expand([1, -1, -1, 4])
|
||||
|
@ -1208,7 +1208,7 @@ class ContiguousModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return x.contiguous()
|
||||
|
@ -1231,7 +1231,7 @@ class LogSoftmaxIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return self.log_softmax.forward(tensor)
|
||||
|
@ -1296,9 +1296,9 @@ class ReturnThreeTensorFloat32(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b, c):
|
||||
return a, b, c
|
||||
|
@ -1320,9 +1320,9 @@ class AddCMulModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, tensor1, tensor2):
|
||||
return torch.addcmul(input, tensor1, tensor2, value=1.0)
|
||||
|
@ -1344,9 +1344,9 @@ class AddCDivModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, tensor1, tensor2):
|
||||
return torch.addcdiv(input, tensor1, tensor2, value=1.0)
|
||||
|
@ -1412,7 +1412,7 @@ class DropoutEvalIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.dropout(x, 0.2, train=False)
|
||||
|
@ -1434,7 +1434,7 @@ class DropoutEvalFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.dropout(x, 0.1, train=False)
|
||||
|
@ -1456,7 +1456,7 @@ class DropoutTrainModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
res = torch.dropout(x, 0.3, train=True)
|
||||
|
@ -1479,7 +1479,7 @@ class NumelModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.numel(input)
|
||||
|
@ -1523,7 +1523,7 @@ class BoolTensorReturnFalseModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.bool, True),
|
||||
([-1], torch.bool, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return a
|
||||
|
@ -1545,7 +1545,7 @@ class BoolTensorReturnTrueModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.bool, True),
|
||||
([-1], torch.bool, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return a
|
||||
|
@ -1567,7 +1567,7 @@ class BoolTensorReturnMixedModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.bool, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return a
|
||||
|
@ -1589,8 +1589,8 @@ class BoolTensorHandleSignless(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.bool, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.bool, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return a * b
|
||||
|
@ -1614,7 +1614,7 @@ class TModuleRank2(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs):
|
||||
return torch.t(lhs)
|
||||
|
@ -1636,7 +1636,7 @@ class TModuleRank1(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs):
|
||||
return torch.t(lhs)
|
||||
|
@ -1726,8 +1726,8 @@ class ReturnTwoTensorF32I64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return a, b
|
||||
|
@ -1749,8 +1749,8 @@ class IndexTensorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index):
|
||||
return torch.ops.aten.index(x, (index, ))
|
||||
|
@ -1772,8 +1772,8 @@ class IndexTensorModule3dInput(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index):
|
||||
return torch.ops.aten.index(x, (index,))
|
||||
|
@ -1795,8 +1795,8 @@ class IndexTensorSelectDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, ind):
|
||||
return torch.ops.aten.index(a, (None, ind, None))
|
||||
|
@ -1817,7 +1817,7 @@ class IndexTensorMultiInput(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([3, 3], torch.int64, True),
|
||||
([3], torch.int64, True),
|
||||
])
|
||||
|
@ -1841,7 +1841,7 @@ class IndexTensorMultiInputOneDim(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([6, 1], torch.int64, True),
|
||||
([3], torch.int64, True),
|
||||
])
|
||||
|
@ -1865,9 +1865,9 @@ class IndexTensorMultiInputContiguousOneDimDynamic(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 1], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, 1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index1, index2):
|
||||
return torch.ops.aten.index(x, (
|
||||
|
@ -1895,9 +1895,9 @@ class IndexTensorMultiInputNonContiguousOneDimDynamic(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 1], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, 1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index1, index2):
|
||||
return torch.ops.aten.index(x, (
|
||||
|
@ -1926,9 +1926,9 @@ class IndexTensorMultiInputNonContiguousDynamic(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, 2], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, 2], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index1, index2):
|
||||
return torch.ops.aten.index(x, (
|
||||
|
@ -1956,10 +1956,10 @@ class IndexTensorMultiInputNonContiguousMultipleStaticDims(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([4, 1], torch.int64, True),
|
||||
([1, 3], torch.int64, True),
|
||||
([-9223372036854775808, 3], torch.int64, True),
|
||||
([-1, 3], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index1, index2, index3):
|
||||
return torch.ops.aten.index(x, (index1, index2, index3))
|
||||
|
@ -1984,7 +1984,7 @@ class IndexTensorMultiInputNonContiguous(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([4, 2], torch.int64, True),
|
||||
([4, 2], torch.int64, True),
|
||||
])
|
||||
|
@ -2008,7 +2008,7 @@ class IndexTensorMultiInputThreeIndexers(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1, -1, -1], torch.float32, True),
|
||||
([8, 4, 2], torch.int64, True),
|
||||
([8, 1, 1], torch.int64, True),
|
||||
([4, 2], torch.int64, True),
|
||||
|
@ -2036,7 +2036,7 @@ class IndexTensorMultiInputContiguousCenter(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([2, 2], torch.int64, True),
|
||||
([2], torch.int64, True),
|
||||
])
|
||||
|
@ -2060,8 +2060,8 @@ class IndexTensorHackedTwinModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index):
|
||||
return torch.ops.aten.index(x, [index])
|
||||
|
@ -2083,8 +2083,8 @@ class IndexTensorHackedTwinModule3dInput(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index):
|
||||
return torch.ops.aten.index(x, [index])
|
||||
|
@ -2108,10 +2108,10 @@ class IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims(
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([4, 1], torch.int64, True),
|
||||
([1, 3], torch.int64, True),
|
||||
([-9223372036854775808, 3], torch.int64, True),
|
||||
([-1, 3], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, index1, index2, index3):
|
||||
return torch.ops.aten.index(x, [index1, index2, index3])
|
||||
|
@ -2137,7 +2137,7 @@ class SquareModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.square(x)
|
||||
|
@ -2159,7 +2159,7 @@ class HardswishModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardswish(x)
|
||||
|
@ -2181,7 +2181,7 @@ class HardswishRandomModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardswish(x)
|
||||
|
@ -2203,7 +2203,7 @@ class SiluModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.silu(x)
|
||||
|
@ -2225,7 +2225,7 @@ class HardTanhModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardtanh(x, min_val=-2, max_val=2)
|
||||
|
@ -2247,7 +2247,7 @@ class HardTanhIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardtanh(x, min_val=-2, max_val=2)
|
||||
|
@ -2269,7 +2269,7 @@ class BincountModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.bincount(x)
|
||||
|
@ -2313,7 +2313,7 @@ class BincountMinlengthModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.bincount(x, minlength=600)
|
||||
|
@ -2335,8 +2335,8 @@ class ExpandAsFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, 1, 1], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, 1, 1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.expand_as(x, y)
|
||||
|
@ -2356,7 +2356,7 @@ class ExpandAsIntModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
([1, 1, 1], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.expand_as(x, y)
|
||||
|
@ -2379,8 +2379,8 @@ class CopyModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
@ -2399,8 +2399,8 @@ class CopyWithDifferentSizesModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, 4], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, 1], torch.float32, True),
|
||||
([-1, -1, 4], torch.float32, True),
|
||||
([-1, -1, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
@ -2419,8 +2419,8 @@ class CopyWithDifferentDTypesModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
@ -2439,8 +2439,8 @@ class CopyWithDifferentDTypesAndSizesModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, 4], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, 1], torch.int64, True),
|
||||
([-1, -1, 4], torch.float32, True),
|
||||
([-1, -1, 1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.copy_(x, y)
|
||||
|
@ -2463,7 +2463,7 @@ class ToCopyModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten._to_copy(x)
|
||||
|
@ -2482,7 +2482,7 @@ class ToCopyWithDTypeModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten._to_copy(x, dtype=torch.int64)
|
||||
|
@ -2501,7 +2501,7 @@ class ToCopyWithDTypeFalsePinMemoryModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten._to_copy(x, dtype=torch.int64, pin_memory=False)
|
||||
|
@ -2543,7 +2543,7 @@ class FlipModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.flip(x, [1, 2])
|
||||
|
@ -2565,7 +2565,7 @@ class DetachModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.detach(x)
|
||||
|
@ -2650,9 +2650,9 @@ class BaddbmmDynamicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2)
|
||||
|
@ -2692,9 +2692,9 @@ class BaddbmmDifferentDtypesModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2)
|
||||
|
@ -2714,9 +2714,9 @@ class BaddbmmWithAlphaModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2, alpha=5)
|
||||
|
@ -2735,9 +2735,9 @@ class BaddbmmWithBetaModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2, beta=0.5)
|
||||
|
@ -2756,9 +2756,9 @@ class BaddbmmWithAlphaBetaModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, batch1, batch2):
|
||||
return torch.ops.aten.baddbmm(input, batch1, batch2, beta=6, alpha=2.4)
|
||||
|
@ -2841,7 +2841,7 @@ class NumpyTRankNDynamicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs):
|
||||
return torch.ops.aten.numpy_T(lhs)
|
||||
|
@ -2860,7 +2860,7 @@ class NumpyTRank2Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs):
|
||||
return torch.ops.aten.numpy_T(lhs)
|
||||
|
@ -2879,7 +2879,7 @@ class NumpyTRank1Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs):
|
||||
return torch.ops.aten.numpy_T(lhs)
|
||||
|
@ -2916,9 +2916,9 @@ class AtenEmbeddingBagSumExample(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, weight, indices, offsets):
|
||||
return torch.ops.aten.embedding_bag(weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None)
|
||||
|
@ -2938,9 +2938,9 @@ class Aten_EmbeddingBagExample(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, weight, indices, offsets):
|
||||
return torch.ops.aten._embedding_bag(weight, indices, offsets)
|
||||
|
@@ -2962,7 +2962,7 @@ class CumsumModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, val):
         return torch.ops.aten.cumsum(val, 1)
@@ -2997,7 +2997,7 @@ class AtenToDeviceModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808 , -9223372036854775808], torch.float32, True),
+        ([-1 , -1], torch.float32, True),
     ])

     def forward(self, val):
@@ -3018,7 +3018,7 @@ class UpSampleNearest2dBackward(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1, -1], torch.float64, True),
     ])
     def forward(self, input):
         return torch.ops.aten.upsample_nearest2d_backward(input,
@@ -3041,7 +3041,7 @@ class UpSampleNearest2dBackwardScalesNone(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, input):
         return torch.ops.aten.upsample_nearest2d_backward(input,

@@ -37,7 +37,7 @@ class TensorToInt(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1], torch.int64, True),
     ])
     def forward(self, x):
         return int(x)
@@ -75,7 +75,7 @@ class TensorToFloat(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1], torch.float64, True),
     ])
     def forward(self, x):
         return float(x)
@@ -113,7 +113,7 @@ class TensorToBool(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.bool, True),
+        ([-1, -1], torch.bool, True),
     ])
     def forward(self, x):
         return bool(x)

@@ -319,7 +319,7 @@ class EmptyLikeDefaultDtypeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.empty_like(a).fill_(0)
@@ -338,7 +338,7 @@ class EmptyLikeIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.empty_like(a, dtype=torch.int32).fill_(0)
@@ -357,7 +357,7 @@ class EmptyLikeMemoryFormatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.empty_like(a,
@@ -377,7 +377,7 @@ class EmptyLikeFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.empty_like(a, dtype=torch.float32).fill_(0)
@@ -396,7 +396,7 @@ class EmptyLikeFalsePinMemoryModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.empty_like(a, dtype=torch.float64,
@@ -419,7 +419,7 @@ class ZerosLikeDefaultDtypeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.zeros_like(a)
@@ -438,7 +438,7 @@ class ZerosLikeIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.zeros_like(a, dtype=torch.int32)
@@ -457,7 +457,7 @@ class ZerosLikeFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.zeros_like(a, dtype=torch.float32)
@@ -476,7 +476,7 @@ class ZerosLikeFalsePinMemoryModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.zeros_like(a, dtype=torch.float64, pin_memory=False)
@@ -498,7 +498,7 @@ class OnesLikeDefaultDtypeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ones_like(a)
@@ -517,7 +517,7 @@ class OnesLikeIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.ones_like(a, dtype=torch.int32)
@@ -536,7 +536,7 @@ class OnesLikeFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ones_like(a, dtype=torch.float32)
@@ -555,7 +555,7 @@ class OnesLikeFalsePinMemoryModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ones_like(a, dtype=torch.float64, pin_memory=False)
@@ -577,7 +577,7 @@ class NewZerosModuleDefaultDtype(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_zeros(a, [3, 4])
@@ -596,7 +596,7 @@ class NewZerosModuleInt2D(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.int64)
@@ -615,7 +615,7 @@ class NewZerosModuleInt3D(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.int64)
@@ -634,7 +634,7 @@ class NewZerosModuleFloat2D(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_zeros(a, [3, 4], dtype=torch.float32)
@@ -653,7 +653,7 @@ class NewZerosModuleFloat3D(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_zeros(a, [3, 4, 5], dtype=torch.float32)
@@ -672,7 +672,7 @@ class NewZerosModuleFalsePinMemory(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_zeros(a, [3, 4],
@@ -696,7 +696,7 @@ class NewOnesModuleDefaultDtype(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_ones(a, [3, 4])
@@ -715,7 +715,7 @@ class NewOnesModuleInt2D(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.int64)
@@ -734,7 +734,7 @@ class NewOnesModuleInt3D(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.int64)
@@ -753,7 +753,7 @@ class NewOnesModuleFloat2D(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_ones(a, [3, 4], dtype=torch.float32)
@@ -772,7 +772,7 @@ class NewOnesModuleFloat3D(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_ones(a, [3, 4, 5], dtype=torch.float32)
@@ -791,7 +791,7 @@ class NewOnesModuleFalsePinMemory(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.new_ones(a, [3, 4],
@ -929,7 +929,7 @@ class FullLikeModuleDefaultDtype(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.full_like(a, 5)
|
||||
|
@ -948,7 +948,7 @@ class FullLikeModuleInt2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.full_like(a, 10.5)
|
||||
|
@ -967,7 +967,7 @@ class FullLikeModuleInt3D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.full_like(a, 5.0, dtype=torch.int64)
|
||||
|
@ -1005,7 +1005,7 @@ class FullLikeModuleFloat2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.full_like(a, 10)
|
||||
|
@ -1024,7 +1024,7 @@ class FullLikeModuleFloat3D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.full_like(a, 15, dtype=torch.float32)
|
||||
|
@ -1062,7 +1062,7 @@ class FullLikeModuleFalsePinMemory(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.full_like(a,
|
||||
|
@ -1087,7 +1087,7 @@ class ZeroFloat32Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten.zero_(tensor)
|
||||
|
@ -1106,7 +1106,7 @@ class ZeroInt32Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten.zero_(tensor)
|
||||
|
@ -1125,7 +1125,7 @@ class ZeroInt64Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten.zero_(tensor)
|
||||
|
@ -1147,7 +1147,7 @@ class NewEmptyModuleDefaultDtype(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
|
||||
|
@ -1166,7 +1166,7 @@ class NewEmptyModuleInt2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4], dtype=torch.int64).fill_(0)
|
||||
|
@ -1185,7 +1185,7 @@ class NewEmptyModuleInt3D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4, 5],
|
||||
|
@ -1205,7 +1205,7 @@ class NewEmptyModuleFloat2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4],
|
||||
|
@ -1225,7 +1225,7 @@ class NewEmptyModuleFloat3D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4, 5],
|
||||
|
@ -1245,7 +1245,7 @@ class NewEmptyModuleFalsePinMemory(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4],
|
||||
|
@ -1266,7 +1266,7 @@ class NewEmptyModuleNonDefaultFloatDtype(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
|
||||
|
@ -1286,7 +1286,7 @@ class NewEmptyModuleNonDefaultIntDtype(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4]).fill_(0)
|
||||
|
@ -1305,7 +1305,7 @@ class NewEmptyModuleLayoutIntDtype(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.new_empty(a, [3, 4], layout=0).fill_(0)
|
||||
|
@ -1327,8 +1327,8 @@ class MaskedFillScalarDefaultModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.bool, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
def forward(self, x, mask):
|
||||
return torch.ops.aten.masked_fill(x, mask, value=0.5)
|
||||
|
@ -1348,8 +1348,8 @@ class MaskedFillScalarIntValueModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.bool, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
def forward(self, x, mask):
|
||||
return torch.ops.aten.masked_fill(x, mask, value=5)
|
||||
|
@ -1369,8 +1369,8 @@ class MaskedFillScalarFloatValueModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.bool, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
])
|
||||
def forward(self, x, mask):
|
||||
return torch.ops.aten.masked_fill(x, mask, value=-0.01)
|
||||
|
@ -1390,8 +1390,8 @@ class MaskedFillTensorFloatValueModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.bool, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.bool, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, mask, value):
|
||||
|
|
|
@@ -20,7 +20,7 @@ class TorchPrimLoopForLikeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True)
+        ([-1, -1], torch.int64, True)
     ])
     def forward(self, x):
         x_val = x.size(0)
@@ -42,7 +42,7 @@ class TorchPrimLoopWhileLikeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True)
+        ([-1, -1], torch.int64, True)
     ])
     def forward(self, x):
         x_val = x.size(0)

@ -22,7 +22,7 @@ class Conv2dNoPaddingModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
@ -45,7 +45,7 @@ class Conv2dBiasNoPaddingModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
@ -68,7 +68,7 @@ class Conv2dWithPaddingModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
@ -97,7 +97,7 @@ class Conv2dWithPaddingDilationStrideModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
@ -149,8 +149,8 @@ class Convolution2DModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -200,8 +200,8 @@ class Convolution2DStridedModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -225,8 +225,8 @@ class _Convolution2DAllFalseModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -254,8 +254,8 @@ class _Convolution2DBenchmarkModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -283,8 +283,8 @@ class _Convolution2DDeterministicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -312,8 +312,8 @@ class _Convolution2DCudnnModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -341,8 +341,8 @@ class _Convolution2DTF32Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -370,8 +370,8 @@ class _ConvolutionDeprecated2DAllFalseModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -398,8 +398,8 @@ class _ConvolutionDeprecated2DBenchmarkModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -426,8 +426,8 @@ class _ConvolutionDeprecated2DDeterministicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -454,8 +454,8 @@ class _ConvolutionDeprecated2DCudnnModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten._convolution(inputVec,
|
||||
|
@ -482,8 +482,8 @@ class ConvolutionModule2DGroups(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -510,8 +510,8 @@ class ConvolutionModule2DTranspose(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -537,8 +537,8 @@ class ConvolutionModule2DTransposeStrided(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.convolution(inputVec,
|
||||
|
@ -592,8 +592,8 @@ class Conv_Transpose2dModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec, weight):
|
||||
return torch.ops.aten.conv_transpose2d(inputVec,
|
||||
|
@ -619,7 +619,7 @@ class UpSampleNearest2d(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.upsample_nearest2d(input,
|
||||
|
@ -640,7 +640,7 @@ class UpSampleNearest2dSameSize(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec):
|
||||
return torch._C._nn.upsample_nearest2d(inputVec,
|
||||
|
@@ -660,7 +660,7 @@ class UpSampleNearest2dDiffSize(torch.nn.Module):
         super().__init__()

     @export
-    @annotate_args([None, ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True)])
+    @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
     def forward(self, inputVec):
         return torch._C._nn.upsample_nearest2d(inputVec,
                                                output_size=[8, 11],
@@ -679,7 +679,7 @@ class UpSampleNearest2dDiffFactor(torch.nn.Module):
         super().__init__()

     @export
-    @annotate_args([None, ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True)])
+    @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
     def forward(self, inputVec):
         return torch._C._nn.upsample_nearest2d(inputVec,
                                                output_size=[6, 10],
@ -700,7 +700,7 @@ class UpSampleNearest2dSameFactor(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, inputVec):
|
||||
return torch._C._nn.upsample_nearest2d(inputVec,
|
||||
|
|
|
@ -24,7 +24,7 @@ class CustomOpExampleModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops._torch_mlir_custom_op_example.identity(a)
|
||||
|
|
|
@ -27,7 +27,7 @@ class ElementwiseUnaryModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.tanh(a)
|
||||
|
@ -49,7 +49,7 @@ class ElementwiseUnaryIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.tanh(a)
|
||||
|
@ -71,8 +71,8 @@ class ElementwiseBinaryModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return a * b
|
||||
|
@ -118,9 +118,9 @@ class ElementwiseTernaryModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b, c):
|
||||
return torch.lerp(a, b, c)
|
||||
|
@ -166,9 +166,9 @@ class ElementwiseWhereSelfModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b, c):
|
||||
return torch.where(a > 0.5, b, c)
|
||||
|
@ -190,7 +190,7 @@ class ElementwiseWhereScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.where(a > 0.5, 4.0, 8.0)
|
||||
|
@ -212,8 +212,8 @@ class ElementwiseWhereScalarOtherModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.where(a > 0.5, b, 8.0)
|
||||
|
@ -235,8 +235,8 @@ class ElementwiseWhereScalarSelfModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.where(a > 0.5, 4.0, b)
|
||||
|
@ -260,7 +260,7 @@ class ElementwiseAddModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
|
@ -283,7 +283,7 @@ class ElementwiseUnsqueezeBroadcastModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
|
@ -307,7 +307,7 @@ class ElementwiseUnsqueezeNegDimsModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
# As mentioned in `unsqueeze` docstring,
|
||||
|
@ -332,7 +332,7 @@ class ElementwiseFlattenBroadcastModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
|
@ -355,7 +355,7 @@ class ElementwiseReluModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.relu(x)
|
||||
|
@ -377,7 +377,7 @@ class ElementwiseRelu6Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.relu6(x)
|
||||
|
@ -399,7 +399,7 @@ class ElementwiseLeakyReluModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.leaky_relu(x, negative_slope=0.1)
|
||||
|
@ -422,7 +422,7 @@ class ElementwiseGeluModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.gelu(x)
|
||||
|
@ -444,7 +444,7 @@ class ElementwiseSigmoidModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.sigmoid(x)
|
||||
|
@ -466,7 +466,7 @@ class ElementwiseSigmoidIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.sigmoid(x)
|
||||
|
@ -488,8 +488,8 @@ class ElementwiseMinimumModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.minimum(x, y)
|
||||
|
@ -511,8 +511,8 @@ class ElementwiseMinimumIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.minimum(x, y)
|
||||
|
@ -534,8 +534,8 @@ class ElementwiseMaximumModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.maximum(x, y)
|
||||
|
@ -557,8 +557,8 @@ class ElementwiseMaximumIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.maximum(x, y)
|
||||
|
@ -580,7 +580,7 @@ class ElementwiseClampModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
float_min = torch.clamp(x, min=-2.0)
|
||||
|
@ -607,7 +607,7 @@ class ElementwiseClampMinModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
float_min = torch.ops.aten.clamp_min(x, min=-2.0)
|
||||
|
@ -632,7 +632,7 @@ class ElementwiseClampMaxModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
float_max = torch.ops.aten.clamp_max(x, max=2.0)
|
||||
|
@ -657,7 +657,7 @@ class RsubFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.rsub(x, 3.0, alpha=1.0)
|
||||
|
@ -679,7 +679,7 @@ class RsubFloatModule_noalpha(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.rsub(x, 2.0)
|
||||
|
@ -701,7 +701,7 @@ class RsubIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.rsub(x, 2, alpha=3)
|
||||
|
@ -723,7 +723,7 @@ class RsubIntModule_noalpha(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.rsub(x, 2.)
|
||||
|
@ -745,7 +745,7 @@ class ElementwiseMulScalarIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.mul(x, 4)
|
||||
|
@ -767,7 +767,7 @@ class ElementwiseMulScalarFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.mul(x, 100.0)
|
||||
|
@ -789,7 +789,7 @@ class ElementwiseMulScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.mul(x, 8.0)
|
||||
|
@ -811,8 +811,8 @@ class ElementwiseMulTensorFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.mul(a, b)
|
||||
|
@ -834,8 +834,8 @@ class ElementwiseMulTensorIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.mul(a, b)
|
||||
|
@ -858,7 +858,7 @@ class ElementwiseMishModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.mish(x)
|
||||
|
@ -880,8 +880,8 @@ class ElementwiseAtan2TensorFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.atan2(a, b)
|
||||
|
@ -903,8 +903,8 @@ class ElementwiseAtan2TensorIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.atan2(a, b)
|
||||
|
@ -927,8 +927,8 @@ class ElementwiseAtan2FloatIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.atan2(a, b)
|
||||
|
@ -951,7 +951,7 @@ class ElementwiseLogModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.log(a)
|
||||
|
@ -973,7 +973,7 @@ class ElementwiseLogIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.log(a)
|
||||
|
@ -994,7 +994,7 @@ class ElementwiseLog1pModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.log1p(a)
|
||||
|
@ -1016,7 +1016,7 @@ class ElementwiseErfModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.erf(a)
|
||||
|
@ -1038,7 +1038,7 @@ class ElementwiseErfIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ops.aten.erf(a)
|
||||
|
@ -1060,7 +1060,7 @@ class ElementwiseSqrtModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sqrt(a)
|
||||
|
@ -1082,7 +1082,7 @@ class ElementwiseSqrtIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sqrt(a)
|
||||
|
@ -1104,7 +1104,7 @@ class ElementwiseFloorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.floor(a)
|
||||
|
@ -1126,7 +1126,7 @@ class ElementwiseCeilModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.ceil(a)
|
||||
|
@ -1148,7 +1148,7 @@ class ElementwisePowModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.pow(a, 2.0)
|
||||
|
@ -1170,8 +1170,8 @@ class ElementwisePowTensorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.pow(a, b)
|
||||
|
@ -1193,8 +1193,8 @@ class ElementwisePowTensorBroadcastModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, 1], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, 1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.pow(a, b)
|
||||
|
@ -1214,7 +1214,7 @@ class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||
def forward(self, x):
|
||||
return x.to(torch.int64)
|
||||
|
||||
|
@ -1233,7 +1233,7 @@ class ElementwiseToDtypeIdentityModule(torch.nn.Module):
|
|||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
|
||||
@annotate_args([None, ([-1, -1], torch.float32, True)])
|
||||
def forward(self, x):
|
||||
return x.to(torch.float32, False, False)
|
||||
|
||||
|
@ -1254,7 +1254,7 @@ class ElementwiseLog2Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.log2(a)
|
||||
|
@ -1276,7 +1276,7 @@ class ElementwiseLog2IntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.log2(a)
|
||||
|
@ -1298,7 +1298,7 @@ class ElementwiseRsqrtModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.rsqrt(a)
|
||||
|
@ -1320,7 +1320,7 @@ class ElementwiseRsqrtIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.rsqrt(a)
|
||||
|
@ -1342,7 +1342,7 @@ class ElementwiseAbsModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.abs(a)
|
||||
|
@ -1364,7 +1364,7 @@ class ElementwiseReciprocalModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.reciprocal(a)
|
||||
|
@ -1386,7 +1386,7 @@ class ElementwiseReciprocalIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int32, True),
|
||||
([-1], torch.int32, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.reciprocal(a)
|
||||
|
@ -1408,7 +1408,7 @@ class ElementwiseDivScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.div(x, 10.0)
|
||||
|
@ -1429,7 +1429,7 @@ class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int32, True),
|
||||
([-1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.remainder(x, 2.0)
|
||||
|
@ -1451,7 +1451,7 @@ class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.remainder(x, 2.0)
|
||||
|
@ -1472,7 +1472,7 @@ class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.remainder(x, 2)
|
||||
|
@ -1492,7 +1492,7 @@ class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.bool, True),
|
||||
([-1], torch.bool, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.remainder(x, 2)
|
||||
|
@ -1514,8 +1514,8 @@ class ElementwiseDivTensorFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b)
|
||||
|
@ -1537,8 +1537,8 @@ class ElementwiseDivRoundingModeTruncModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b, rounding_mode="trunc")
|
||||
|
@ -1558,8 +1558,8 @@ class ElementwiseDivRoundingModeFloorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a, b):
|
||||
return torch.div(a, b, rounding_mode="floor")
|
||||
|
@ -1582,8 +1582,8 @@ class ElementwiseAndIntegerModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.bitwise_and(x, y)
|
||||
|
@ -1607,8 +1607,8 @@ class ElementwiseOrIntegerModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.bitwise_or(x, y)
|
||||
|
@ -1632,7 +1632,7 @@ class ElementwiseNotIntegerModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.bitwise_not(x)
|
||||
|
@@ -1654,7 +1654,7 @@ class ElementwiseNotInt32Module(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int32, True),
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, x):
        return torch.bitwise_not(x)
@@ -1676,7 +1676,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int32, True),
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, x):
        return torch.sub(x, 2.1, alpha=2)
@@ -1698,7 +1698,7 @@ class ElementwiseSubScalarFloatModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.sub(x, 2.1)
@@ -1720,7 +1720,7 @@ class ElementwiseAddScalarInt64Module(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
        ([-1, -1], torch.int64, True),
    ])
    def forward(self, x):
        return torch.add(x, 3.0)
@@ -1742,7 +1742,7 @@ class ElementwiseAddScalarIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int32, True),
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, x):
        return torch.add(x, 3.0)
@@ -1764,7 +1764,7 @@ class ElementwiseAddScalarFloatModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.add(x, 3.0, alpha=2)
@@ -1786,7 +1786,7 @@ class ElementwiseCloneModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.clone(x)
@@ -1808,7 +1808,7 @@ class ElementwiseCloneContiguousModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.clone(x, memory_format=torch.contiguous_format)
@@ -1830,7 +1830,7 @@ class LiftFreshCopyModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.lift_fresh_copy(x)
@@ -1852,7 +1852,7 @@ class ElementwiseExpModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.exp(a)
@@ -1874,7 +1874,7 @@ class ElementwiseExpIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int32, True),
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.exp(a)
@@ -1896,7 +1896,7 @@ class ElementwiseExpm1Module(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.special.expm1(a)
@@ -1918,7 +1918,7 @@ class ElementwiseExpm1IntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int32, True),
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.special.expm1(a)
@@ -1940,7 +1940,7 @@ class ElementwiseSinModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.sin(a)
@@ -1962,7 +1962,7 @@ class ElementwiseSinIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int32, True),
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.sin(a)
@@ -1984,7 +1984,7 @@ class ElementwiseCosModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.cos(a)
@@ -2006,7 +2006,7 @@ class ElementwiseCosIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int32, True),
        ([-1, -1], torch.int32, True),
    ])
    def forward(self, a):
        return torch.cos(a)
@@ -2028,7 +2028,7 @@ class ElementwiseNegModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.neg(a)
@ -2047,8 +2047,8 @@ class ElementwiseAtenLogicalOrOpModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.bool, True),
|
||||
([-9223372036854775808], torch.bool, True),
|
||||
([-1], torch.bool, True),
|
||||
([-1], torch.bool, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
@ -2064,8 +2064,8 @@ class ElementwiseAtenLogicalOrOpDiffArgs1Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
@ -2083,8 +2083,8 @@ class ElementwiseAtenLogicalOrOpDiffArgs2Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.bool, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.bool, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
@ -2102,8 +2102,8 @@ class ElementwiseAtenLogicalOrOpDiffArgs3Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.bool, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.bool, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
@ -2121,8 +2121,8 @@ class ElementwiseAtenLogicalOrOpRandomModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1, -1], torch.int64, True),
|
||||
([-1, -1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
@ -2140,8 +2140,8 @@ class ElementwiseAtenLogicalOrOpRandomFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
@ -2159,8 +2159,8 @@ class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1, -1], torch.int64, True),
|
||||
([-1, -1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
@ -2178,8 +2178,8 @@ class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.logical_or(x, y)
|
||||
|
@ -2200,8 +2200,8 @@ class ElementwiseAtenFloorDivideModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.floor_divide(x, y)
|
||||
|
@ -2220,8 +2220,8 @@ class ElementwiseAtenFloorDivideBroadcastModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.floor_divide(x, y)
|
||||
|
@ -2244,7 +2244,7 @@ class AtenTriuModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.triu(x)
|
||||
|
@ -2266,7 +2266,7 @@ class AtenTriuWithPosDiagonalModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.triu(x, diagonal=2)
|
||||
|
@ -2288,7 +2288,7 @@ class AtenTriuWithNegDiagonalModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.triu(x, diagonal=-4)
|
||||
|
@ -2309,7 +2309,7 @@ class AtenRoundFloatModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.round(x)
|
||||
|
@ -2327,7 +2327,7 @@ class AtenRoundIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.round(x)
|
||||
|
@ -2349,7 +2349,7 @@ class Fill_TensorFloat64WithFloat32(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten.fill_(tensor, 3.0)
|
||||
|
@ -2368,7 +2368,7 @@ class Fill_TensorFloat64WithFloat64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten.fill_(tensor, 3.0)
|
||||
|
@ -2387,7 +2387,7 @@ class Fill_TensorFloat64WithInt64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor):
|
||||
return torch.ops.aten.fill_(tensor, 3)
|
||||
|
@ -2409,7 +2409,7 @@ class Fill_TensorFloat32WithFloat32(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, tensor, value):
|
||||
|
@ -2428,7 +2428,7 @@ class Fill_TensorFloat32WithFloat64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([], torch.float64, True),
|
||||
])
|
||||
def forward(self, tensor, value):
|
||||
|
@ -2447,7 +2447,7 @@ class Fill_TensorFloat32WithInt64(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
def forward(self, tensor, value):
@@ -18,7 +18,7 @@ class ElementwiseGtFloatScalarModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.gt(x, 0.6)
@ -37,7 +37,7 @@ class ElementwiseGtIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 10)
|
||||
|
@ -56,7 +56,7 @@ class ElementwiseGtMixed2ScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.gt(x, 7)
|
||||
|
@ -75,7 +75,7 @@ class ElementwiseGeFloatScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 0.6)
|
||||
|
@ -94,7 +94,7 @@ class ElementwiseGeIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 10)
|
||||
|
@ -113,7 +113,7 @@ class ElementwiseGeMixedIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 7)
|
||||
|
@ -132,7 +132,7 @@ class ElementwiseGeFloatIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ge(x, 7)
|
||||
|
@ -151,8 +151,8 @@ class ElementwiseGtFloatTensorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x, y)
|
||||
|
@ -171,8 +171,8 @@ class ElementwiseGtIntTensorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.gt(x, y)
|
||||
|
@ -191,7 +191,7 @@ class ElementwiseLtFloatScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 0.6)
|
||||
|
@ -210,7 +210,7 @@ class ElementwiseLtIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 0)
|
||||
|
@ -229,7 +229,7 @@ class ElementwiseLtDiffWidthScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.lt(x, 2)
|
||||
|
@ -249,7 +249,7 @@ class ElementwiseLeFloatScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.le(x, 0.6)
|
||||
|
@ -268,7 +268,7 @@ class ElementwiseLeIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.le(x, 10)
|
||||
|
@ -287,7 +287,7 @@ class ElementwiseLeMixedIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.le(x, 7)
|
||||
|
@ -306,7 +306,7 @@ class ElementwiseLeFloatIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.le(x, 7)
|
||||
|
@ -325,8 +325,8 @@ class ElementwiseLtFloatTensorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x, y)
|
||||
|
@ -345,8 +345,8 @@ class ElementwiseLtIntTensorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.lt(x, y)
|
||||
|
@ -365,7 +365,7 @@ class ElementwiseEqFloatScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 6.0)
|
||||
|
@ -385,7 +385,7 @@ class ElementwiseEqIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 2)
|
||||
|
@ -404,7 +404,7 @@ class ElementwiseEqDiffWidthScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int32, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.eq(x, 2)
|
||||
|
@ -424,8 +424,8 @@ class ElementwiseEqFloatTensorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.eq(x, y)
|
||||
|
@ -446,8 +446,8 @@ class ElementwiseEqIntTensorModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.eq(x, y)
|
||||
|
@ -466,7 +466,7 @@ class ElementwiseNeFloatScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ne(x, 2.0)
|
||||
|
@ -486,7 +486,7 @@ class ElementwiseNeIntScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ne(x, 3)

@@ -47,9 +47,9 @@ class HistogramBinningCalibrationByFeature(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.int32, True),
        ([-9223372036854775808], torch.int32, True),
        ([-9223372036854775808], torch.float32, True),
        ([-1], torch.int32, True),
        ([-1], torch.int32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, segment_value, segment_lengths, logit):
        origin_prediction = torch.sigmoid(
@ -20,9 +20,9 @@ class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index, ),
|
||||
|
@ -45,9 +45,9 @@ class IndexPutImpl2DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index, ),
|
||||
|
@ -70,9 +70,9 @@ class IndexPutImpl3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index, ),
|
||||
|
@ -99,9 +99,9 @@ class IndexPutImpl1DIntNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index, ),
|
||||
|
@ -128,9 +128,9 @@ class IndexPutImpl1DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index, ),
|
||||
|
@ -153,9 +153,9 @@ class IndexPutImpl2DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input.clone(), (index, ),
|
||||
|
@ -178,9 +178,9 @@ class IndexPutImpl3DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input.clone(), (index, ),
|
||||
|
@ -207,9 +207,9 @@ class IndexPutImpl1DIntAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten._index_put_impl_(input, (index, ),
|
||||
|
@ -235,9 +235,9 @@ class IndexPut1DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -259,9 +259,9 @@ class IndexPut2DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -283,9 +283,9 @@ class IndexPut3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -311,9 +311,9 @@ class IndexPut1DIntNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -335,9 +335,9 @@ class IndexPut2DIntNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -359,9 +359,9 @@ class IndexPut3DIntNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -386,9 +386,9 @@ class IndexPut1DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -409,9 +409,9 @@ class IndexPut2DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -432,9 +432,9 @@ class IndexPut3DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -459,9 +459,9 @@ class IndexPut1DIntAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -483,9 +483,9 @@ class IndexPut2DIntAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -507,9 +507,9 @@ class IndexPut3DIntAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, (index, ),
|
||||
|
@ -535,9 +535,9 @@ class IndexPutHackedTwin1DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index],
|
||||
|
@ -559,9 +559,9 @@ class IndexPutHackedTwin2DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index],
|
||||
|
@ -583,9 +583,9 @@ class IndexPutHackedTwin3DFloatNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index],
|
||||
|
@ -611,9 +611,9 @@ class IndexPutHackedTwin1DIntNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index],
|
||||
|
@ -636,9 +636,9 @@ class IndexPutHackedTwin2DIntNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index],
|
||||
|
@ -661,9 +661,9 @@ class IndexPutHackedTwin3DIntNonAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index],
|
||||
|
@ -689,9 +689,9 @@ class IndexPutHackedTwin1DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index], value, accumulate=True)
|
||||
|
@ -711,9 +711,9 @@ class IndexPutHackedTwin2DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index], value, accumulate=True)
|
||||
|
@ -733,9 +733,9 @@ class IndexPutHackedTwin3DFloatAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index], value, accumulate=True)
|
||||
|
@ -759,9 +759,9 @@ class IndexPutHackedTwin1DIntAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index], value, accumulate=True)
|
||||
|
@ -782,9 +782,9 @@ class IndexPutHackedTwin2DIntAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index], value, accumulate=True)
|
||||
|
@ -805,9 +805,9 @@ class IndexPutHackedTwin3DIntAccumulateModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1, -1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, input, index, value):
|
||||
return torch.ops.aten.index_put(input, [index], value, accumulate=True)
|
||||
|
|
|
@ -95,8 +95,8 @@ class IndexSelectDynamicModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
|
||||
def forward(self, input, indices):
|
||||
|
@ -114,7 +114,7 @@ class IndexSelectDynamicInputSizeModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([2], torch.int64, True),
|
||||
])
|
||||
|
||||
|
@ -134,7 +134,7 @@ class IndexSelectDynamicIndexSizeModule(torch.nn.Module):
|
|||
@annotate_args([
|
||||
None,
|
||||
([4, 5, 6], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
|
||||
    def forward(self, input, indices):

@@ -18,8 +18,8 @@ class MatmulDot(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.float32, True),
        ([-9223372036854775808], torch.float32, True),
        ([-1], torch.float32, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)
@ -38,8 +38,8 @@ class Matmul2D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -58,8 +58,8 @@ class MatmulVecMat(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -78,8 +78,8 @@ class MatmulMatVec(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -98,8 +98,8 @@ class Matmul3D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -118,8 +118,8 @@ class Matmul4d(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -178,8 +178,8 @@ class MatmulSingleDynamicBatchDim(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([4, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([4, -1, -1, -1], torch.float32, True),
|
||||
([4, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -198,8 +198,8 @@ class MatmulBroadcastBatchDim(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([4, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
@ -216,8 +216,8 @@ class Mv(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
])
|
||||
def forward(self, m, v):
|
||||
return torch.mv(m, v)

@@ -24,7 +24,7 @@ class Mlp1LayerModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return self.tanh0(self.fc0(x))
@ -46,7 +46,7 @@ class Mlp2LayerModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
x = self.tanh0(self.fc0(x))
|
||||
|
@ -70,7 +70,7 @@ class Mlp2LayerModuleNoBias(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
x = self.tanh0(self.fc0(x))
|
||||
|
@ -91,7 +91,7 @@ class BatchMlpLayerModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return self.tanh0(self.fc0(x))

@@ -20,8 +20,8 @@ class NllLossModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-9223372036854775808], torch.int64, True),
        ([-1, -1], torch.float32, True),
        ([-1], torch.int64, True),
    ])
    # Here the 2nd index is ignored.
    def forward(self, x, y):
@ -44,8 +44,8 @@ class NllLossModule_mean(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
|
@ -68,8 +68,8 @@ class NllLossModule_sum(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
def forward(self, x, y):
|
||||
|
@ -92,7 +92,7 @@ class NllLossModule_1D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.int64, True),
|
||||
])
|
||||
# Here the 2nd index is ignored.
|
||||
|
@ -117,8 +117,8 @@ class NllLossModule_ignore_index_out_of_bounds(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
# None of the index is ignored here, since the ignored index is out of bounds.
|
||||
def forward(self, x, y):
|
||||
|
@ -141,9 +141,9 @@ class NllLossModule_backward(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
|
@ -170,10 +170,10 @@ class NllLossModule_backwardWeight(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
|
@ -201,9 +201,9 @@ class NllLossModule_backward_ignore_index(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
|
@ -231,9 +231,9 @@ class NllLossModule_backwardMean(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
|
@ -260,10 +260,10 @@ class NllLossModule_backwardMeanWeight(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
|
@ -290,9 +290,9 @@ class NllLossModule_backwardSum(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
|
@ -319,10 +319,10 @@ class NllLossModule_backwardSumWeight(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
|
@ -349,9 +349,9 @@ class NllLossModule_backward1D(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
|
@ -378,10 +378,10 @@ class NllLossModule_backward1DWeight(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
|
@ -408,9 +408,9 @@ class NllLossModule_backward1DMean(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
|
@ -437,10 +437,10 @@ class NllLossModule_backward1DMeanWeight(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, weight, total_weight):
|
||||
|
@ -467,9 +467,9 @@ class NllLossModule_backward1DSum(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.float32, True),
|
||||
([-9223372036854775808], torch.int64, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, grad_output, input, target, total_weight):
|
||||
|
@@ -496,10 +496,10 @@ class NllLossModule_backward1DSumWeight(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.int64, True),
-        ([-9223372036854775808], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.int64, True),
+        ([-1], torch.float32, True),
         ([], torch.float32, True),
     ])
     def forward(self, grad_output, input, target, weight, total_weight):

@@ -122,11 +122,11 @@ class NativeBatchNorm1DModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
     ])
     def forward(self, x, weight, bias, running_mean, running_var):
         return torch.ops.aten.native_batch_norm(
@@ -148,11 +148,11 @@ class NativeBatchNorm2DModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
     ])
     def forward(self, x, weight, bias, running_mean, running_var):
         return torch.ops.aten.native_batch_norm(
@@ -174,11 +174,11 @@ class NativeBatchNorm3DModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
     ])
     def forward(self, x, weight, bias, running_mean, running_var):
         return torch.ops.aten.native_batch_norm(
@@ -200,10 +200,10 @@ class NativeBatchNormNoneWeightModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
-        ([-9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
+        ([-1], torch.float32, True),
     ])
     def forward(self, x, bias, running_mean, running_var):
         return torch.ops.aten.native_batch_norm(
@@ -245,9 +245,9 @@ class NativeLayerNormDynamicModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x, weight, bias):
         list = [2, 2, 3]

@@ -43,7 +43,7 @@ class AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return self.aap2d(x)
@@ -86,7 +86,7 @@ class AdaptiveAvgPool2dUnitOutputSizeDynamicModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return self.aap2d(x)
@@ -113,7 +113,7 @@ class MaxPool2dModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return self.mp2d(x)
@@ -160,7 +160,7 @@ class MaxPool2dCeilModeTrueModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return self.mp2d(x)
@@ -182,7 +182,7 @@ class MaxPool2dWithIndicesModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.max_pool2d_with_indices(x,
@@ -205,7 +205,7 @@ class MaxPool2dWithIndicesFullSizeKernelModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.max_pool2d_with_indices(x,
@@ -229,7 +229,7 @@ class MaxPool2dWithIndicesNonDefaultPaddingModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.max_pool2d_with_indices(x,
@@ -253,7 +253,7 @@ class MaxPool2dWithIndicesNonDefaultStrideModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.max_pool2d_with_indices(x,
@@ -277,7 +277,7 @@ class MaxPool2dWithIndicesNonDefaultDilationModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.max_pool2d_with_indices(x,
@@ -301,7 +301,7 @@ class MaxPool2dWithIndicesNonDefaultParamsModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.max_pool2d_with_indices(x,
@@ -325,7 +325,7 @@ class MaxPool2dWithIndicesAllNegativeValuesModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.max_pool2d_with_indices(x,
@@ -372,7 +372,7 @@ class MaxPool2dWithIndicesAllOnesModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.max_pool2d_with_indices(x,
@@ -395,7 +395,7 @@ class MaxPool2dWithIndicesCeilModeTrueModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.max_pool2d_with_indices(x,
@@ -483,9 +483,9 @@ class MaxPool2dWithIndicesBackwardDynamic4DModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
+        ([-1, -1, -1, -1], torch.int64, True),
     ])
     def forward(self, output, input, indices):
         kernel_size = [2, 2]
@@ -513,9 +513,9 @@ class MaxPool2dWithIndicesBackwardDynamic3DModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1, -1], torch.int64, True),
     ])
     def forward(self, output, input, indices):
         kernel_size = [2, 2]
@@ -552,7 +552,7 @@ class AvgPool2dFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return self.ap2d(x)
@@ -575,7 +575,7 @@ class AvgPool2dIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1, -1], torch.int64, True),
     ])
     def forward(self, x):
         return self.ap2d(x)
@@ -650,7 +650,7 @@ class AvgPool2dCeilModeTrueModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return self.ap2d(x)

@@ -18,7 +18,7 @@ class ReduceSumFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.sum(a)
@@ -37,7 +37,7 @@ class ReduceSumDtypeFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, a):
         return torch.sum(a, dtype=torch.float32)
@@ -56,7 +56,7 @@ class ReduceSumElementTypeBoolModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.bool, True),
+        ([-1, -1, -1], torch.bool, True),
     ])
     def forward(self, a):
         return torch.sum(a)
@@ -75,7 +75,7 @@ class ReduceSumDimIntListFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.sum(a, (0, 1))
@@ -94,7 +94,7 @@ class ReduceSumDimIntListDtypeFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, a):
         return torch.sum(a, (0, 1), dtype=torch.float32)
@@ -113,7 +113,7 @@ class ReduceSumDimIntListKeepDimFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.sum(a, (1, 2), keepdim=True)
@@ -151,7 +151,7 @@ class ReduceSumDimIntListEmptyDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.sum(a, dim=[])
@@ -170,7 +170,7 @@ class ReduceSumDimIntListElementTypeBoolModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.bool, True),
+        ([-1, -1], torch.bool, True),
     ])
     def forward(self, a):
         return torch.sum(a, dim=(-1), keepdim=False)
@@ -189,7 +189,7 @@ class ReduceSumUnsignedIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.sum(a)
@@ -208,7 +208,7 @@ class ReduceSumSignedIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.sum(a)
@@ -227,7 +227,7 @@ class ReduceSumDtypeIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int32, True),
+        ([-1, -1, -1], torch.int32, True),
     ])
     def forward(self, a):
         return torch.sum(a, dtype=torch.int64)
@@ -246,7 +246,7 @@ class ReduceSumDimIntListIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.sum(a, (0, 1))
@@ -265,7 +265,7 @@ class ReduceSumDimIntListDtypeIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int32, True),
+        ([-1, -1, -1], torch.int32, True),
     ])
     def forward(self, a):
         return torch.sum(a, (0, 1), dtype=torch.int64)
@@ -284,7 +284,7 @@ class ReduceSumDimIntListKeepDimIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.sum(a, (1, 2), keepdim=True)
@@ -303,7 +303,7 @@ class ReduceMaxAlongDim(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.max(a, 1)[0]
@@ -322,7 +322,7 @@ class ReduceMaxAlongDimNegative(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.max(a, 1)[0]
@@ -341,7 +341,7 @@ class ReduceMaxKeepDim(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.max(a, 1, keepdim=True)[1]
@@ -360,7 +360,7 @@ class ReduceMaxKeepDimReturnBoth(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.max(a, 1, keepdim=True)
@@ -379,7 +379,7 @@ class ReduceMaxAllDims(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.max(a)
@@ -397,7 +397,7 @@ class ReduceMaxNegativeDim(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.max(a, -1, keepdim=True)
@@ -415,7 +415,7 @@ class ReduceMaxFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.max(a)
@@ -433,7 +433,7 @@ class ReduceMaxSignedIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.max(a)
@@ -451,7 +451,7 @@ class ReduceMaxUnsignedIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1, -1], torch.int64, True),
     ])
     def forward(self, a):
         return torch.ops.aten.max(a)
@@ -469,7 +469,7 @@ class ReduceL1NormModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.linalg.vector_norm(a, dim=0, ord=1)
@@ -487,7 +487,7 @@ class ReduceL1NormWithDTypeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.linalg.vector_norm(a, dim=0, ord=1, dtype=torch.float64)
@@ -505,7 +505,7 @@ class ReduceL2NormModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.linalg.vector_norm(a, dim=0)
@@ -523,7 +523,7 @@ class ReduceLN3NormModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.linalg.vector_norm(a, dim=0, ord=-3)
@@ -541,7 +541,7 @@ class ReduceL3NormAllDimsModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.linalg.vector_norm(a, dim=None, ord=3)
@@ -559,7 +559,7 @@ class ReduceL3NormKeepDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.linalg.vector_norm(a, keepdim=True, ord=3)
@@ -576,7 +576,7 @@ class ReduceFrobeniusNormModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=False)
@@ -593,7 +593,7 @@ class ReduceFrobeniusNormKeepDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.ops.aten.frobenius_norm(a, dim=[0, 1], keepdim=True)
@@ -611,8 +611,8 @@ class MseLossNoReductionModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808 , -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808 , -9223372036854775808], torch.float32, True),
+        ([-1 , -1], torch.float32, True),
+        ([-1 , -1], torch.float32, True),
     ])

     def forward(self, x, y):
@@ -630,8 +630,8 @@ class MseLossMeanReductionModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808 , -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808 , -9223372036854775808], torch.float32, True),
+        ([-1 , -1], torch.float32, True),
+        ([-1 , -1], torch.float32, True),
     ])

     def forward(self, x, y):
@@ -649,8 +649,8 @@ class MseLossSumReductionWithDifferentElemTypeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808 , -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808 , -9223372036854775808], torch.float64, True),
+        ([-1 , -1], torch.float32, True),
+        ([-1 , -1], torch.float64, True),
     ])

     def forward(self, x, y):

@@ -112,7 +112,7 @@ class ViewDynamicExpandModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, 30, 384], torch.float32, True),
+        ([-1, -1, 30, 384], torch.float32, True),
     ])

     def forward(self, a):
@@ -131,7 +131,7 @@ class ViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -150,7 +150,7 @@ class ViewCollapseModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -169,7 +169,7 @@ class ViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1, -1, -1], torch.float32, True),
         ([], torch.int64, True),
         ([], torch.int64, True),
     ])
@@ -228,7 +228,7 @@ class ViewDynamicExpandCollapseModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, 4, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, 4, -1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -247,7 +247,7 @@ class ViewDynamicExpandCollapseWithAtenIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -342,11 +342,11 @@ class View1DFoldModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808], torch.float32, True),
+        ([-1], torch.float32, True),
     ])

     def forward(self, a):
-        return a.view(-9223372036854775808)
+        return a.view(-1)

 @register_test_case(module_factory=lambda: View1DFoldModule())
 def View1DFoldModule_basic(module, tu: TestUtils):
@@ -365,7 +365,7 @@ class ViewCollapseInferredDimModule(torch.nn.Module):
     ])

     def forward(self, a):
-        return a.view(-9223372036854775808, 4)
+        return a.view(-1, 4)

 @register_test_case(module_factory=lambda: ViewCollapseInferredDimModule())
 def ViewCollapseInferredDimModule_basic(module, tu: TestUtils):
@@ -384,7 +384,7 @@ class ViewExpandInferredDimModule(torch.nn.Module):
     ])

     def forward(self, a):
-        return a.view(3, -9223372036854775808, 2)
+        return a.view(3, -1, 2)

 @register_test_case(module_factory=lambda: ViewExpandInferredDimModule())
 def ViewExpandInferredDimModule_basic(module, tu: TestUtils):
@@ -399,7 +399,7 @@ class ViewExpandDynamicDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([1, -9223372036854775808, 128], torch.float32, True),
+        ([1, -1, 128], torch.float32, True),
     ])

     def forward(self, a):
@@ -418,7 +418,7 @@ class ViewFlattenAndExpandModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -456,7 +456,7 @@ class UnsafeViewDynamicExpandModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, 30, 384], torch.float32, True),
+        ([-1, -1, 30, 384], torch.float32, True),
     ])

     def forward(self, a):
@@ -475,7 +475,7 @@ class UnsafeViewDynamicExpandWithAtenSizeIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -494,7 +494,7 @@ class UnsafeViewCollapseModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -513,7 +513,7 @@ class UnsafeViewCollapseDynamicWithAtenSizeIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1, -1, -1], torch.float32, True),
         ([], torch.int64, True),
         ([], torch.int64, True),
     ])
@@ -534,11 +534,11 @@ class UnsafeView1DFoldModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808], torch.float32, True),
+        ([-1], torch.float32, True),
     ])

     def forward(self, a):
-        return torch.ops.aten._unsafe_view(a, [-9223372036854775808])
+        return torch.ops.aten._unsafe_view(a, [-1])

 @register_test_case(module_factory=lambda: UnsafeView1DFoldModule())
 def UnsafeView1DFoldModule_basic(module, tu: TestUtils):
@@ -553,7 +553,7 @@ class ReshapeExpandModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808], torch.float32, True),
+        ([-1], torch.float32, True),
     ])

     def forward(self, a):
@@ -572,7 +572,7 @@ class ReshapeCollapseModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -591,7 +591,7 @@ class ViewNoChange1dModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808], torch.float32, True),
+        ([-1], torch.float32, True),
     ])

     def forward(self, a):
@@ -609,7 +609,7 @@ class ViewNoChange2dModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -627,7 +627,7 @@ class ViewNoChange3dModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])

     def forward(self, a):
@@ -666,7 +666,7 @@ class ReshapeAliasExpandModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808], torch.float32, True),
+        ([-1], torch.float32, True),
     ])

     def forward(self, a):
@@ -685,7 +685,7 @@ class ReshapeAliasCollapseModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])

     def forward(self, a):

@@ -20,11 +20,11 @@ class TestMultipleTensorReturn(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808], torch.int32, True),
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
-        ([-9223372036854775808, -9223372036854775808], torch.bool, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float64, True),
+        ([-1, -1], torch.int32, True),
+        ([-1, -1], torch.int64, True),
+        ([-1, -1], torch.bool, True),
     ])
     def forward(self, a, b, c, d, e):
         return a, b, c, d, e
@@ -48,9 +48,9 @@ class TestMultipleTensorAndPrimitiveTypesReturn(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int32, True),
-        ([-9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808], torch.bool, True),
+        ([-1, -1], torch.int32, True),
+        ([-1, -1], torch.float64, True),
+        ([-1, -1], torch.bool, True),
     ])
     def forward(self, a, b, c):
         d = 1

@@ -14,9 +14,9 @@ class UniformModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x, y, z):
         a = torch.ops.aten.uniform_(x, 1.0, 10.0)
@@ -89,9 +89,9 @@ class BernoulliModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x, y, z):
         a = torch.bernoulli(x)
@@ -126,7 +126,7 @@ class BernoulliZerosModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1], torch.float64, True),
     ])
     def forward(self, x):
         return torch.bernoulli(x)
@@ -145,7 +145,7 @@ class BernoulliOnesModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1], torch.float64, True),
     ])
     def forward(self, x):
         return torch.bernoulli(x)
@@ -164,9 +164,9 @@ class BernoulliFloatModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x, y, z):
         a = torch.ops.aten.bernoulli_(x, 0.4)
@@ -201,12 +201,12 @@ class BernoulliTensorModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x, px, y, py, z, pz):
         a = torch.ops.aten.bernoulli_(x, px)
@@ -244,7 +244,7 @@ class RandLikeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1], torch.float64, True),
     ])
     def forward(self, x):
         a = torch.ops.aten.rand_like(x)
@@ -265,7 +265,7 @@ class RandLikeDtypeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1], torch.float64, True),
     ])
     def forward(self, x):
         a = torch.ops.aten.rand_like(x, dtype=torch.float32)

@@ -18,7 +18,7 @@ class SliceModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return x[0:5:1, 1:3:1, 2:4:1]
@@ -58,7 +58,7 @@ class SliceOutOfUpperBoundIndexModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         # TODO: remove hacky cat tensor once refbackend supports 0 size dim
@@ -80,7 +80,7 @@ class SliceOutOfLowerBoundEndIndexModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return x[:-8,-7:,:]
@@ -99,7 +99,7 @@ class SliceOutOfLowerBoundStartIndexModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return x[-8:3:1, 1:3:1, 2:4:1]
@@ -119,7 +119,7 @@ class SliceEndSleStartModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         # TODO: remove hacky cat tensor once refbackend supports 0 size dim
@@ -142,7 +142,7 @@ class SliceStartEqEndModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         # TODO: remove hacky cat tensor once refbackend supports 0 size dim
@@ -164,7 +164,7 @@ class SliceSizeTwoStepModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return x[0:5:2, 0:3:2, 0:4:2]
@@ -183,7 +183,7 @@ class SliceNegIdxModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return x[:-1, -2:-1]
@@ -202,7 +202,7 @@ class SliceSingleIdxModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return x[0]
@@ -221,7 +221,7 @@ class SliceWholeTensorModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return x[:, :]
@@ -240,7 +240,7 @@ class SelectIntModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
+        ([-1, -1], torch.int64, True),
     ])
     def forward(self, x):
         return x.select(0,0)
@@ -261,8 +261,8 @@ class SliceScatterModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x, src):
         return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 1)
@@ -278,8 +278,8 @@ class SliceScatterZeroDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x, src):
         return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 0, end = 1, step = 1)
@@ -298,8 +298,8 @@ class SliceScatterNegativeDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x, src):
         return torch.ops.aten.slice_scatter(x,
@@ -321,8 +321,8 @@ class SliceScatterStepVariationModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x, src):
         return torch.ops.aten.slice_scatter(x, src, dim = 1, start = 0, end = 1, step = 2)
@@ -357,8 +357,8 @@ class SelectScatterModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x, src):
         return torch.ops.aten.select_scatter(x, src, dim = 0, index = 0)
@@ -395,7 +395,7 @@ class NarrowHorizontalTest(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.narrow(x, dim=0, start=0, length=2)
@@ -415,7 +415,7 @@ class NarrowVerticalTest(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.narrow(x, dim=1, start=0, length=2)
@@ -434,7 +434,7 @@ class NarrowHorizontalTest2(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.narrow(x, dim=0, start=0, length=2)
@@ -454,7 +454,7 @@ class NarrowVerticalTest2(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.narrow(x, dim=1, start=0, length=2)

@@ -63,7 +63,7 @@ class SqueezeBroadcastModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
         ([], torch.float32, True),
     ])
     def forward(self, a, b):
@@ -108,7 +108,7 @@ class SqueezeDimDynamicModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, 1, 384, -9223372036854775808, 1], torch.float32, True),
+        ([-1, 1, 384, -1, 1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.squeeze(a, 4)
@@ -130,7 +130,7 @@ class SqueezeDimNegDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([1, -9223372036854775808, 1, 384, -9223372036854775808, 1], torch.float32, True),
+        ([1, -1, 1, 384, -1, 1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.squeeze(a, -6)
@@ -152,7 +152,7 @@ class SqueezeDimIdentityModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([4, 1, -9223372036854775808], torch.float32, True),
+        ([4, 1, -1], torch.float32, True),
     ])
     def forward(self, a):
         return torch.squeeze(a, 0)

@@ -37,7 +37,7 @@ class MeanDynamicSizesModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x)
@@ -56,7 +56,7 @@ class MeanDtypeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, dtype=torch.float32)
@@ -75,7 +75,7 @@ class MeanLargeInputModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x)
@@ -94,7 +94,7 @@ class MeanDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, (0, 2))
@@ -113,7 +113,7 @@ class MeanDimLargeInputModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, (0, 2))
@@ -133,7 +133,7 @@ class MeanDimDtypeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, (0,), dtype=torch.float32)
@@ -152,7 +152,7 @@ class MeanDimKeepdimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, (1, 2), keepdim=True)
@@ -171,7 +171,7 @@ class MeanDimAllReduceModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, (0, 1, 2))
@@ -190,7 +190,7 @@ class MeanDimAllReduceKeepdimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, (0, 1, 2), keepdim=True)
@@ -209,7 +209,7 @@ class MeanDimNegativeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, (-1, 1))
@@ -229,7 +229,7 @@ class MeanDimEmptyDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, dim=[])
@@ -248,7 +248,7 @@ class MeanDimNoneDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.mean(x, dim=None)
@@ -267,7 +267,7 @@ class VarUnbiasedModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, unbiased=True)
@@ -285,7 +285,7 @@ class VarBiasedModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, unbiased=False)
@@ -303,7 +303,7 @@ class StdUnbiasedModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.std(x, unbiased=True)
@@ -321,7 +321,7 @@ class StdBiasedModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.std(x, unbiased=False)
@@ -342,7 +342,7 @@ class StdDimKeepDimFalseModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.std(x, dim=(1, 2), keepdim=False)
@@ -364,7 +364,7 @@ class StdDimKeepDimTrueModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.std(x, dim=(0, 1, 2), keepdim=True)
@@ -386,7 +386,7 @@ class StdDimBiasedModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.std(x, dim=(0, 2), unbiased=False)
@@ -408,7 +408,7 @@ class StdDimEmptyDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.std(x, dim=[], keepdim=False)
@@ -430,7 +430,7 @@ class StdDimNoneDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.std(x, dim=None, keepdim=False)
@@ -452,7 +452,7 @@ class VarDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, dim=(0, 2), keepdim=True)
@@ -474,7 +474,7 @@ class VarDimUnbiasedModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, dim=(0, 2), unbiased=True, keepdim=True)
@@ -496,7 +496,7 @@ class VarDimBiasedModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, dim=(0,1), unbiased=False, keepdim=True)
@@ -518,7 +518,7 @@ class VarDimSingleDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, dim=(0,), keepdim=True)
@@ -540,7 +540,7 @@ class VarDimMultiDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, dim=[0, 2], keepdim=False)
@@ -562,7 +562,7 @@ class VarDimAllDimReduceModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, dim=(0, 1, 2), keepdim=True)
@@ -584,7 +584,7 @@ class VarDimNegativeModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, dim=(-1, 1), keepdim=True)
@@ -606,7 +606,7 @@ class VarDimEmptyDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, dim=[], keepdim=False)
@@ -628,7 +628,7 @@ class VarDimNoneDimModule(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
+        ([-1, -1, -1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.ops.aten.var(x, dim=None, keepdim=False)
@ -650,7 +650,7 @@ class VarCorrectionModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=None, correction=2)
|
||||
|
@ -672,7 +672,7 @@ class VarCorrectionSingleDimReduceModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[1], correction=1)
|
||||
|
@ -694,7 +694,7 @@ class VarCorrectionAllDimReduceModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x,
|
||||
|
@ -719,7 +719,7 @@ class VarCorrectionKeepDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[0, 1], correction=None, keepdim=True)
|
||||
|
@ -741,7 +741,7 @@ class VarCorrectionNoneModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=None, correction=None)
|
||||
|
@ -763,7 +763,7 @@ class VarCorrectionEmptyDimModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[], correction=2)
|
||||
|
@ -785,7 +785,7 @@ class VarCorrectionLargeInputModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var(x, dim=[2, 3], correction=2)
|
||||
|
@ -807,7 +807,7 @@ class VarMeanCorrectionModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var_mean(x, dim=[1, 2], correction=2, keepdim=True)
|
||||
|
@ -829,7 +829,7 @@ class VarMeanCorrectionNoneModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.var_mean(x, dim=None, correction=None, keepdim=False)
|
||||
|
|
|
@@ -19,7 +19,7 @@ class Threshold1dIntI32Module(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.int32, True),
        ([-1], torch.int32, True),
    ])

    def forward(self, input):
@@ -37,7 +37,7 @@ class Threshold1dIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.int64, True),
        ([-1], torch.int64, True),
    ])

    def forward(self, input):
@@ -55,7 +55,7 @@ class Threshold2dIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
        ([-1, -1], torch.int64, True),
    ])

    def forward(self, input):
@@ -73,7 +73,7 @@ class Threshold3dIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
        ([-1, -1, -1], torch.int64, True),
    ])

    def forward(self, input):
@@ -91,7 +91,7 @@ class Threshold1dFloatModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.float32, True),
        ([-1], torch.float32, True),
    ])

    def forward(self, input):
@@ -109,7 +109,7 @@ class Threshold2dFloatModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])

    def forward(self, input):
@@ -127,7 +127,7 @@ class Threshold3dFloatModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])

    def forward(self, input):
@@ -145,8 +145,8 @@ class ThresholdBackward1dIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.int64, True),
        ([-9223372036854775808], torch.int64, True),
        ([-1], torch.int64, True),
        ([-1], torch.int64, True),
    ])

    def forward(self, grad, input):
@@ -164,8 +164,8 @@ class ThresholdBackward2dIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
        ([-1, -1], torch.int64, True),
        ([-1, -1], torch.int64, True),
    ])

    def forward(self, grad, input):
@@ -183,8 +183,8 @@ class ThresholdBackward3dIntModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
        ([-1, -1, -1], torch.int64, True),
        ([-1, -1, -1], torch.int64, True),
    ])

    def forward(self, grad, input):
@@ -202,8 +202,8 @@ class ThresholdBackward1dFloatModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.float32, True),
        ([-9223372036854775808], torch.float32, True),
        ([-1], torch.float32, True),
        ([-1], torch.float32, True),
    ])

    def forward(self, grad, input):
@@ -221,8 +221,8 @@ class ThresholdBackward2dFloatModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])

    def forward(self, grad, input):
@@ -240,8 +240,8 @@ class ThresholdBackward3dFloatModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.float32, True),
    ])

    def forward(self, grad, input):
@@ -259,8 +259,8 @@ class ThresholdBackward1dMixedModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.float32, True),
        ([-9223372036854775808], torch.int64, True),
        ([-1], torch.float32, True),
        ([-1], torch.int64, True),
    ])

    def forward(self, grad, input):
@@ -278,8 +278,8 @@ class ThresholdBackward2dMixedModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.int64, True),
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.int64, True),
        ([-1, -1], torch.float32, True),
    ])

    def forward(self, grad, input):
@@ -297,8 +297,8 @@ class ThresholdBackward3dMixedModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-9223372036854775808, -9223372036854775808, -9223372036854775808], torch.int64, True),
        ([-1, -1, -1], torch.float32, True),
        ([-1, -1, -1], torch.int64, True),
    ])

    def forward(self, grad, input):

@@ -18,7 +18,7 @@ class TypeConversionF32ToF64Module(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
    @annotate_args([None, ([-1, -1], torch.float32, True)])
    def forward(self, x):
        return x.to(torch.float64)

@@ -34,7 +34,7 @@ class TypeConversionF64ToF32Module(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float64, True)])
    @annotate_args([None, ([-1, -1], torch.float64, True)])
    def forward(self, x):
        return x.to(torch.float32)

@@ -50,7 +50,7 @@ class TypeConversionI32ToI64Module(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.int32, True)])
    @annotate_args([None, ([-1, -1], torch.int32, True)])
    def forward(self, x):
        return x.to(torch.int64)

@@ -66,7 +66,7 @@ class TypeConversionI64ToI32Module(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.int64, True)])
    @annotate_args([None, ([-1, -1], torch.int64, True)])
    def forward(self, x):
        return x.to(torch.int32)

@@ -82,7 +82,7 @@ class TypeConversionI1ToI32Module(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
    @annotate_args([None, ([-1, -1], torch.bool, True)])
    def forward(self, x):
        return x.to(torch.int32)

@@ -99,7 +99,7 @@ class TypeConversionI1ToI64Module(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
    @annotate_args([None, ([-1, -1], torch.bool, True)])
    def forward(self, x):
        return x.to(torch.int64)

@@ -116,7 +116,7 @@ class TypeConversionI1ToF32Module(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
    @annotate_args([None, ([-1, -1], torch.bool, True)])
    def forward(self, x):
        return x.to(torch.float32)

@@ -133,7 +133,7 @@ class TypeConversionI1ToF64Module(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.bool, True)])
    @annotate_args([None, ([-1, -1], torch.bool, True)])
    def forward(self, x):
        return x.to(torch.float64)

@@ -153,7 +153,7 @@ class ToDtypeLayoutNoneModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
    @annotate_args([None, ([-1, -1], torch.float32, True)])
    def forward(self, x):
        return torch.ops.aten.to(x,
                                 dtype=torch.float64,
@@ -176,7 +176,7 @@ class ToDtypeLayoutStridedModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
    @annotate_args([None, ([-1, -1], torch.float32, True)])
    def forward(self, x):
        return torch.ops.aten.to(x,
                                 dtype=torch.float64,
@@ -224,8 +224,8 @@ class TypeAsSameModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x, y):
        return torch.ops.aten.type_as(x, y)
@@ -245,7 +245,7 @@ class PrimsConvertElementTypeModule(torch.nn.Module):
        super().__init__()

    @export
    @annotate_args([None, ([-9223372036854775808, -9223372036854775808], torch.float32, True)])
    @annotate_args([None, ([-1, -1], torch.float32, True)])
    def forward(self, x):
        return torch.ops.prims.convert_element_type(x, dtype=torch.int64)

@@ -19,8 +19,8 @@ class TypePromotionSameCategoryDifferentWidthModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.int32, True),
        ([-9223372036854775808], torch.int64, True),
        ([-1], torch.int32, True),
        ([-1], torch.int64, True),
    ])
    def forward(self, a, b):
        return torch.add(a, b, alpha=3)
@@ -41,8 +41,8 @@ class TypePromotionDifferentCategoryModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.int64, True),
        ([-9223372036854775808], torch.float32, True),
        ([-1], torch.int64, True),
        ([-1], torch.float32, True),
    ])
    def forward(self, a, b):
        return torch.add(a, b, alpha=3)
@@ -61,7 +61,7 @@ class TypePromotionSameCategoryZeroRankWiderModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.float32, True),
        ([-1], torch.float32, True),
        ([], torch.float64, True),
    ])
    def forward(self, a, b):
@@ -81,7 +81,7 @@ class TypePromotionZeroRankHigherCategoryModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.int64, True),
        ([-1], torch.int64, True),
        ([], torch.float32, True),
    ])
    def forward(self, a, b):
@@ -101,7 +101,7 @@ class TypePromotionAlphaWiderModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808], torch.float32, True),
        ([-1], torch.float32, True),
        ([], torch.float32, True),
    ])
    def forward(self, a, b):

@@ -24,7 +24,7 @@ class ResNet18Module(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, 3, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, 3, -1, -1], torch.float32, True),
    ])
    def forward(self, img):
        return self.resnet.forward(img)
@@ -64,8 +64,8 @@ class IouOfModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, bbox1, bbox2):
        area1 = (bbox1[:, 2] - bbox1[:, 0]) * (bbox1[:, 3] - bbox1[:, 1])
@@ -95,7 +95,7 @@ class MobilenetV2Module(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, 3, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, 3, -1, -1], torch.float32, True),
    ])
    def forward(self, img):
        return self.mobilenetv2.forward(img)
@@ -118,7 +118,7 @@ class MobilenetV3Module(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-9223372036854775808, 3, -9223372036854775808, -9223372036854775808], torch.float32, True),
        ([-1, 3, -1, -1], torch.float32, True),
    ])
    def forward(self, img):
        return self.mobilenetv3.forward(img)

@@ -18,7 +18,7 @@ func.func @torch.aten.view$twotothree(%arg0: !torch.vtensor<[3,2],f32>) -> !torc
  %0 = torch.prim.ListConstruct %int2, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[3,2],f32>, !torch.list<int> -> !torch.vtensor<[2,3],f32>
  return %1 : !torch.vtensor<[2,3],f32>
}
}

// CHECK-LABEL: func.func @torch.aten.view$dynamictest(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -35,7 +35,7 @@ func.func @torch.aten.view$dynamictest(%arg0: !torch.vtensor<[?,?],f32>) -> !tor
  %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.view %arg0, %2 : !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?,?],f32>
  return %3 : !torch.vtensor<[?,?],f32>
}
}

// CHECK-LABEL: func.func @torch.aten.view$dynamicVal(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,?,128],f32>) -> !torch.vtensor<[16,1,128],f32> {
@@ -54,7 +54,7 @@ func.func @torch.aten.view$dynamicVal(%arg0: !torch.vtensor<[1,?,128],f32>) -> !
  %0 = torch.prim.ListConstruct %int16, %int1, %int128 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,?,128],f32>, !torch.list<int> -> !torch.vtensor<[16,1,128],f32>
  return %1 : !torch.vtensor<[16,1,128],f32>
}
}

// CHECK-LABEL: func.func @torch.aten.view$expandInferredDim(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[2,6],f32>) -> !torch.vtensor<[3,2,2],f32> {
@@ -73,4 +73,4 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) -
  %0 = torch.prim.ListConstruct %int3, %int2, %int-1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[2,6],f32>, !torch.list<int> -> !torch.vtensor<[3,2,2],f32>
  return %1 : !torch.vtensor<[3,2,2],f32>
}
}

@@ -212,7 +212,7 @@ func.func @torch.aten.div$basic(%arg0: !torch.vtensor<[?, ?],f32>, %arg1: !torch
// CHECK: %[[ARG2_BUILTIN:.*]] = torch.constant.bool false
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.constant.none
// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[RESHAPE_SUM:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESHAPE_SUM:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[CONST:.*]] = "tosa.const"() {value = dense<-1.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.mul"(%[[RESHAPE_SUM]], %[[CONST]]) {shift = 0 : i32} : (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
@@ -236,7 +236,7 @@ func.func @test_reduce_mean_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !
// CHECK: %[[ARG3:.*]] = torch.constant.int 0
// CHECK: %[[ARG3_BUILTIN:.*]] = torch.prim.ListConstruct %[[ARG3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[SUM:.*]] = "tosa.reduce_sum"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xf32>) -> tensor<1x?x?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[SUM]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],f32>
func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -292,7 +292,7 @@ func.func @test_reduce_all$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.
// CHECK: %[[ARG1:.*]] = torch.constant.int 0
// CHECK: %[[ARG2:.*]] = torch.constant.bool false
// CHECK: %[[REDUCE:.*]] = "tosa.reduce_any"(%[[ARG0_BUILTIN]]) {axis = 0 : i64} : (tensor<?x?x?x?xi1>) -> tensor<1x?x?x?xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-9223372036854775808, -9223372036854775808, -9223372036854775808]} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
// CHECK: %[[RESULT_BUILTIN:.*]] = "tosa.reshape"(%[[REDUCE]]) {new_shape = [-1, -1, -1]} : (tensor<1x?x?x?xi1>) -> tensor<?x?x?xi1>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[RESULT_BUILTIN]] : tensor<?x?x?xi1> -> !torch.vtensor<[?,?,?],i1>
// CHECK: return %[[RESULT]] : !torch.vtensor<[?,?,?],i1>
func.func @test_reduce_any_dim$basic(%arg0: !torch.vtensor<[?,?,?,?],i1>) -> !torch.vtensor<[?,?,?],i1> {
@@ -479,7 +479,7 @@ func.func @torch.aten.eq.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int -1
// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [-9223372036854775808]} : (tensor<?x?x?x?xf32>) -> tensor<?xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [-1]} : (tensor<?x?x?x?xf32>) -> tensor<?xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?],f32>
// CHECK: }

@@ -202,8 +202,8 @@ func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[CST:.*]]-9223372036854775808 = torch.constant.int -9223372036854775808
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[CST]]-9223372036854775808 : (!torch.int) -> !torch.list<int>
// CHECK: %[[CST:.*]]-1 = torch.constant.int -1
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[CST]]-1 : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLATTEN:.*]] = torch.aten.view %[[INP]], %[[T0]] :
// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
// CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] :
@@ -1009,8 +1009,8 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[INT:.*]]-9223372036854775808 = torch.constant.int -9223372036854775808
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[INT]]-9223372036854775808 : (!torch.int) -> !torch.list<int>
// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[INT]]-1 : (!torch.int) -> !torch.list<int>
// CHECK: %[[T1:.*]] = torch.aten.view %[[ARG0]], %[[T0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
// CHECK: return %[[T1]] : !torch.vtensor<[?],f32>
func.func @torch.aten.flatten.using_ints(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {