Revert "Update Torch ODS list with new ops (#2361)"

This reverts commit 3dd29f9d5d.
revert-2361-glebk/new_ops
powderluv 2023-08-21 16:46:02 -07:00 committed by GitHub
parent 3dd29f9d5d
commit d3fb3fb461
5 changed files with 85 additions and 1055 deletions

View File

@ -41,8 +41,7 @@ if [ ! -z ${TORCH_MLIR_EXT_MODULES} ]; then
ext_module="${TORCH_MLIR_EXT_MODULES}"
fi
set +u
PYTHONPATH="${PYTHONPATH}:${pypath}" python \
PYTHONPATH="${pypath}" python \
-m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \
--torch_ir_include_dir="${torch_ir_include_dir}" \
--pytorch_op_extensions="${ext_module}" \

View File

@ -1278,6 +1278,8 @@ LTC_XFAIL_SET = {
"BoolIntTrueModule_basic",
"CeilFloatModule_basic",
"DivFloatModule_basic",
"ElementwiseAtenFloorDivideBroadcastModule_basic",
"ElementwiseAtenFloorDivideModule_basic",
"EqIntModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
@ -1285,6 +1287,7 @@ LTC_XFAIL_SET = {
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HBC_basic",
"HardtanhBackward_basic",
"IndexPut1DFloatAccumulateModule_basic",
"IndexPut1DFloatNonAccumulateModule_basic",
"IndexPut1DIntAccumulateModule_basic",
@ -1348,6 +1351,8 @@ LTC_XFAIL_SET = {
"NeFloatIntModule_basic",
"NeIntModule_basic",
"QuantizedMLP_basic",
"RandLikeDtypeModule_basic",
"RandLikeModule_basic",
"RollModule_basic",
"ScalarImplicitFloatModule_basic",
"ScalarImplicitIntModule_basic",
@ -1366,14 +1371,20 @@ LTC_XFAIL_SET = {
"TensorToIntZeroRank_basic",
"TensorToInt_basic",
"UniformModule_basic",
"UniformNoCorrelationModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"AtenEmbeddingBagSumExample_basic",
"Aten_EmbeddingBagExample_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderScalarModule_Float_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"ElementwiseRemainderScalarModule_Bool_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"Fill_TensorFloat32WithFloat32_basic",
"Fill_TensorFloat32WithFloat64_basic",
"Fill_TensorFloat32WithInt64_basic",
"UpSampleNearest2dBackwardVec_basic",
"UpSampleNearest2dBackwardOutputSizeNone_basic",
"ConvolutionBackwardModule2D_basic",
@ -1395,8 +1406,24 @@ LTC_XFAIL_SET = {
"NativeDropoutTrainStaticShapeModule_basic",
"StdCorrectionKeepDimModule_basic",
"StdCorrectionNoneModule_basic",
"VarBiasedModule_basic",
"VarCorrectionAllDimReduceModule_basic",
"VarCorrectionEmptyDimModule_basic",
"VarCorrectionKeepDimModule_basic",
"VarCorrectionLargeInputModule_basic",
"VarCorrectionModule_basic",
"VarCorrectionNoneModule_basic",
"VarCorrectionSingleDimReduceModule_basic",
"VarDimAllDimReduceModule_basic",
"VarDimBiasedModule_basic",
"VarDimEmptyDimModule_basic",
"VarDimModule_basic",
"VarDimMultiDimModule_basic",
"VarDimNegativeModule_basic",
"VarDimNoneDimModule_basic",
"VarDimSingleDimModule_basic",
"VarDimUnbiasedModule_basic",
"VarUnbiasedModule_basic",
"AtenFloatScalarModule_basic",
"PrimsSqueezeModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",

View File

@ -113,57 +113,6 @@ def Torch_AtenHardtanh_Op : Torch_Op<"aten.hardtanh_", [
}];
}
def Torch_AtenEluOp : Torch_Op<"aten.elu", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$alpha,
AnyTorchScalarType:$scale,
AnyTorchScalarType:$input_scale
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenEluOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenEluOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenElu_Op : Torch_Op<"aten.elu_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::elu_ : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$alpha,
AnyTorchScalarType:$scale,
AnyTorchScalarType:$input_scale
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenElu_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenElu_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenReluOp : Torch_Op<"aten.relu", [
AllowsTypeRefinement,
HasValueSemantics,
@ -436,51 +385,6 @@ def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [
}];
}
def Torch_AtenSgnOp : Torch_Op<"aten.sgn", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sgn : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSgnOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenSgnOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenSgn_Op : Torch_Op<"aten.sgn_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::sgn_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSgn_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenSgn_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenHardsigmoidOp : Torch_Op<"aten.hardsigmoid", [
AllowsTypeRefinement,
HasValueSemantics,
@ -616,51 +520,6 @@ def Torch_AtenErf_Op : Torch_Op<"aten.erf_", [
}];
}
def Torch_AtenErfinvOp : Torch_Op<"aten.erfinv", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::erfinv : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenErfinvOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenErfinvOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenErfinv_Op : Torch_Op<"aten.erfinv_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::erfinv_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenErfinv_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenErfinv_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenSiluOp : Torch_Op<"aten.silu", [
AllowsTypeRefinement,
HasValueSemantics,
@ -2341,53 +2200,6 @@ def Torch_AtenClampMin_Op : Torch_Op<"aten.clamp_min_", [
}];
}
def Torch_AtenClampMinTensorOp : Torch_Op<"aten.clamp_min.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$min
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClampMinTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenClampMinTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenClampMin_TensorOp : Torch_Op<"aten.clamp_min_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::clamp_min_.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$min
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClampMin_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenClampMin_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenClampMaxOp : Torch_Op<"aten.clamp_max", [
AllowsTypeRefinement,
HasValueSemantics,
@ -2435,53 +2247,6 @@ def Torch_AtenClampMax_Op : Torch_Op<"aten.clamp_max_", [
}];
}
def Torch_AtenClampMaxTensorOp : Torch_Op<"aten.clamp_max.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$max
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClampMaxTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenClampMaxTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenClampMax_TensorOp : Torch_Op<"aten.clamp_max_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::clamp_max_.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$max
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenClampMax_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenClampMax_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenLog2Op : Torch_Op<"aten.log2", [
AllowsTypeRefinement,
HasValueSemantics,
@ -3691,30 +3456,6 @@ def Torch_AtenMishOp : Torch_Op<"aten.mish", [
}];
}
def Torch_AtenXlogyTensorOp : Torch_Op<"aten.xlogy.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenXlogyTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenXlogyTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
@ -4226,32 +3967,6 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [
}];
}
def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$num_samples,
Torch_BoolType:$replacement,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMultinomialOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenMultinomialOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [
AllowsTypeRefinement,
HasValueSemantics,
@ -5285,32 +5000,6 @@ def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [
}];
}
def Torch_AtenNormalFunctionalOp : Torch_Op<"aten.normal_functional", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::normal_functional : (Tensor, float, float, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$mean,
Torch_FloatType:$std,
AnyTorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenNormalFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenNormalFunctionalOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
AllowsTypeRefinement,
HasValueSemantics,
@ -6451,31 +6140,6 @@ def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [
}];
}
def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$A,
Torch_StringType:$mode
);
let results = (outs
AnyTorchTensorType:$Q,
AnyTorchTensorType:$R
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLinalgQrOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 2);
}
void AtenLinalgQrOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 2);
}
}];
}
def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [
AllowsTypeRefinement,
HasValueSemantics,
@ -6678,159 +6342,6 @@ def Torch_AtenNonzeroStaticOp : Torch_Op<"aten.nonzero_static", [
}];
}
def Torch_AtenBinaryCrossEntropyOp : Torch_Op<"aten.binary_cross_entropy", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$target,
AnyTorchOptionalTensorType:$weight,
Torch_IntType:$reduction
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBinaryCrossEntropyOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenBinaryCrossEntropyOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchTensorType:$target,
AnyTorchOptionalTensorType:$weight,
Torch_IntType:$reduction
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenBinaryCrossEntropyBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenBinaryCrossEntropyBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$output,
AnyTorchTensorType:$buffer
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogSigmoidForwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 2);
}
void AtenLogSigmoidForwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 2);
}
}];
}
def Torch_AtenLogSigmoidBackwardOp : Torch_Op<"aten.log_sigmoid_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$self,
AnyTorchTensorType:$buffer
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogSigmoidBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenLogSigmoidBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenSigmoidBackwardOp : Torch_Op<"aten.sigmoid_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchTensorType:$output
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSigmoidBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenSigmoidBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenCosineEmbeddingLossOp : Torch_Op<"aten.cosine_embedding_loss", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input1,
AnyTorchTensorType:$input2,
AnyTorchTensorType:$target,
Torch_FloatType:$margin,
Torch_IntType:$reduction
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCosineEmbeddingLossOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenCosineEmbeddingLossOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics,
@ -7158,61 +6669,6 @@ def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
}];
}
def Torch_AtenEyeOp : Torch_Op<"aten.eye", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
Torch_IntType:$n,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenEyeOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenEyeOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenEyeMOp : Torch_Op<"aten.eye.m", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
Torch_IntType:$n,
Torch_IntType:$m,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenEyeMOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenEyeMOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [
AllowsTypeRefinement,
HasValueSemantics,
@ -7410,31 +6866,6 @@ def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [
}];
}
def Torch_AtenAllDimOp : Torch_Op<"aten.all.dim", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::all.dim : (Tensor, int, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAllDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenAllDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenAnyOp : Torch_Op<"aten.any", [
AllowsTypeRefinement,
HasValueSemantics,
@ -7616,31 +7047,6 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
}];
}
def Torch_AtenArgminOp : Torch_Op<"aten.argmin", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::argmin : (Tensor, int?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchOptionalIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenArgminOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenArgminOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [
AllowsTypeRefinement,
HasValueSemantics,
@ -8731,80 +8137,6 @@ def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [
}];
}
def Torch_AtenMinOp : Torch_Op<"aten.min", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::min : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMinOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenMinOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenMinDimOp : Torch_Op<"aten.min.dim", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$values,
AnyTorchTensorType:$indices
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMinDimOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 2);
}
void AtenMinDimOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 2);
}
}];
}
def Torch_AtenAminOp : Torch_Op<"aten.amin", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::amin : (Tensor, int[], bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$dim,
Torch_BoolType:$keepdim
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAminOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenAminOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [
AllowsTypeRefinement,
ReadOnly
@ -9694,58 +9026,6 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [
}];
}
def Torch_AtenFmodTensorOp : Torch_Op<"aten.fmod.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFmodTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenFmodTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_BoolType:$return_inverse,
Torch_BoolType:$return_counts,
AnyTorchOptionalIntType:$dim
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenUniqueConsecutiveOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 3);
}
void AtenUniqueConsecutiveOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 3);
}
}];
}
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
AllowsTypeRefinement,
HasValueSemantics,
@ -10186,60 +9466,6 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [
}];
}
def Torch_AtenIm2colOp : Torch_Op<"aten.im2col", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$kernel_size,
AnyTorchListOfTorchIntType:$dilation,
AnyTorchListOfTorchIntType:$padding,
AnyTorchListOfTorchIntType:$stride
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIm2colOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenIm2colOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenScatterReduceOp : Torch_Op<"aten.scatter.reduce", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$src,
Torch_StringType:$reduce
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenScatterReduceOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenScatterReduceOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}
def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [
AllowsTypeRefinement,
HasValueSemantics,
@ -11455,30 +10681,6 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
}];
}
def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRemainderTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenRemainderTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [
AllowsTypeRefinement,
HasValueSemantics,
@ -12736,34 +11938,6 @@ def Torch_AtenNativeDropoutBackwardOp : Torch_Op<"aten.native_dropout_backward",
}];
}
def Torch_AtenEluBackwardOp : Torch_Op<"aten.elu_backward", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$grad_output,
AnyTorchScalarType:$alpha,
AnyTorchScalarType:$scale,
AnyTorchScalarType:$input_scale,
Torch_BoolType:$is_result,
AnyTorchTensorType:$self_or_result
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenEluBackwardOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
}
void AtenEluBackwardOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -8,7 +8,6 @@
//===----------------------------------------------------------------------===//
#include <ATen/ATen.h>
#include <ATen/ops/where.h>
#include <c10/util/Optional.h>
#include <cmath>
@ -18,85 +17,47 @@
namespace torch {
namespace lazy {
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the
// future.
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future.
std::vector<torch::lazy::Shape> compute_shape_add(const at::Tensor& self,
const at::Scalar& other,
const at::Scalar& alpha) {
std::vector<torch::lazy::Shape>
compute_shape_div(const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_sub(const at::Tensor& self,
const at::Scalar& other,
const at::Scalar& alpha) {
std::vector<torch::lazy::Shape>
compute_shape_mse_loss_backward(
const at::Tensor& grad_output,
const at::Tensor& self,
const at::Tensor& target,
int64_t reduction) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_mse_loss_backward(
const at::Tensor& grad_output, const at::Tensor& self,
const at::Tensor& target, int64_t reduction) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_mul(const at::Tensor& self,
const at::Scalar& other) {
std::vector<torch::lazy::Shape>
compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_var(
const at::Tensor& self, at::OptionalIntArrayRef dim,
const c10::optional<at::Scalar> & correction, bool keepdim) {
c10::optional<int64_t> correction, bool keepdim) {
// Result of variance is scalar tensor.
return {Shape(self.scalar_type(), {})};
}
std::vector<torch::lazy::Shape> compute_shape_hardtanh(
const at::Tensor& self, const at::Scalar& min_val,
const at::Scalar& max_val) {
const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val
) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_hardtanh_backward(
const at::Tensor& grad_output, const at::Tensor& self,
const at::Scalar& min_val, const at::Scalar& max_val) {
std::vector<torch::lazy::Shape> compute_shape_where(
const at::Tensor & condition,
const at::Tensor & self,
const at::Tensor & other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_where(const at::Tensor& condition,
const at::Tensor& self,
const at::Tensor& other) {
// There are cases like -
// torch.aten.where.self %42, %arg17, %37 : !torch.vtensor<[15,10],i1>,
// !torch.vtensor<[],f32>, !torch.vtensor<[15,10],f32>.
// So the result tensor would the biggest of all the three operands.
auto condition_meta = at::native::empty_strided_meta_symint(
condition.sym_sizes(), condition.sym_strides(),
/*dtype=*/c10::make_optional(condition.scalar_type()),
/*layout=*/c10::make_optional(condition.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto self_meta = at::native::empty_strided_meta_symint(
self.sym_sizes(), self.sym_strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto other_meta = at::native::empty_strided_meta_symint(
other.sym_sizes(), other.sym_strides(),
/*dtype=*/c10::make_optional(other.scalar_type()),
/*layout=*/c10::make_optional(other.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::where(condition_meta, self_meta, other_meta);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_bucketize(
const at::Tensor& self, const at::Tensor& boundaries, bool out_int32,
bool right) {
@ -104,64 +65,50 @@ std::vector<torch::lazy::Shape> compute_shape_bucketize(
return {Shape(dtype, self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_copy(const at::Tensor& self,
const at::Tensor& src,
bool non_blocking) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_floor_divide(
const at::Tensor& self, const at::Tensor& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_fmod(const at::Tensor& self,
const at::Scalar& other) {
std::vector<torch::lazy::Shape> compute_shape_copy(
const at::Tensor& self,
const at::Tensor& src,
bool non_blocking) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
int64_t group, double eps) {
const at::Tensor& input,
const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias,
int64_t N, int64_t C, int64_t HxW,
int64_t group, double eps) {
TORCH_CHECK(input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
TORCH_CHECK(
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
// A separate mean and var needs to be kept for each group per N.
shapes.emplace_back(at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{N, group});
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{N, group});
shapes.emplace_back(at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{N, group});
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{N, group});
return shapes;
}
std::vector<torch::lazy::Shape> compute_shape_im2col(
const at::Tensor& self, at::IntArrayRef kernel_size,
at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) {
auto self_meta = at::native::empty_strided_meta_symint(
self.sym_sizes(), self.sym_strides(),
/*dtype=*/c10::make_optional(self.scalar_type()),
/*layout=*/c10::make_optional(self.layout()),
/*device=*/c10::make_optional(c10::Device(c10::kMeta)),
/*pin_memory=*/c10::nullopt);
auto out_meta = at::im2col(self_meta, kernel_size, dilation, padding, stride);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean,
const at::Tensor& rstd, const c10::optional<at::Tensor>& weight, int64_t N,
int64_t C, int64_t HxW, int64_t group, ::std::array<bool, 3> output_mask) {
const at::Tensor& grad_out,
const at::Tensor& input,
const at::Tensor& mean,
const at::Tensor& rstd,
const c10::optional<at::Tensor>& weight,
int64_t N, int64_t C, int64_t HxW,
int64_t group, ::std::array<bool, 3> output_mask) {
TORCH_CHECK(input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
TORCH_CHECK(
input.sizes().size() >= 2,
"Input tensor must have at least batch and channel dimensions!");
std::vector<torch::lazy::Shape> shapes;
shapes.reserve(3);
shapes.emplace_back(input.scalar_type(), input.sizes().vec());
@ -169,102 +116,15 @@ std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
int64_t num_features = input.size(1);
// `weight` and `bias` are vectors of length C (number of channels)`
shapes.emplace_back(at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
shapes.emplace_back(at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
shapes.emplace_back(
at::get_default_dtype_as_scalartype(),
std::vector<int64_t>{num_features});
return shapes;
}
std::vector<torch::lazy::Shape> compute_shape_remainder(
const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_uniform(
const at::Tensor& self, double from, double to,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_normal_functional(
const at::Tensor& self, double mean, double std,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_multinomial(
const at::Tensor& self, int64_t num_samples, bool replacement,
c10::optional<at::Generator> generator) {
// Input tensor can be either 1D or 2D. The last dim of output
// should be 'num_samples'. So the output shape can be either
// [num_samples] or [m, num_samples].
// Output type can only be long tensor.
auto ishape = self.sizes().vec();
ishape.back() = num_samples;
return {Shape(at::kLong, ishape)};
}
std::vector<torch::lazy::Shape> compute_shape_eye(
int64_t n, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
auto out_meta =
at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_eye(
int64_t n, int64_t m, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
auto out_meta =
at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_full(
at::IntArrayRef size, const at::Scalar& fill_value,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self,
const at::Tensor& value) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_randn(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_randint(
int64_t high, at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_randint(
int64_t low, int64_t high, at::IntArrayRef size,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_bernoulli(
const at::Tensor& self, const at::Tensor &p,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
} // namespace lazy
} // namespace torch
} // namespace lazy
} // namespace torch

View File

@ -242,18 +242,15 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
for key in [
"aten::tanh : (Tensor) -> (Tensor)",
"aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)",
"aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)",
"aten::relu : (Tensor) -> (Tensor)",
"aten::relu6 : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::sign : (Tensor) -> (Tensor)",
"aten::sgn : (Tensor) -> (Tensor)",
"aten::hardsigmoid : (Tensor) -> (Tensor)",
"aten::hardswish : (Tensor) -> (Tensor)",
"aten::erf : (Tensor) -> (Tensor)",
"aten::erfinv : (Tensor) -> (Tensor)",
"aten::silu : (Tensor) -> (Tensor)",
"aten::sin : (Tensor) -> (Tensor)",
"aten::exp : (Tensor) -> (Tensor)",
@ -290,9 +287,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)",
"aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)",
"aten::clamp_min : (Tensor, Scalar) -> (Tensor)",
"aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::clamp_max : (Tensor, Scalar) -> (Tensor)",
"aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::log2 : (Tensor) -> (Tensor)",
"aten::sqrt : (Tensor) -> (Tensor)",
"aten::log1p : (Tensor) -> (Tensor)",
@ -314,18 +309,17 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
# variants.
emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::mish : (Tensor) -> (Tensor)")
emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit("aten::gelu : (Tensor, str) -> (Tensor)")
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
@ -350,7 +344,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)")
emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
@ -400,9 +393,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit(
"aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)"
)
emit(
"aten::normal_functional : (Tensor, float, float, Generator?) -> (Tensor)",
)
emit(
"aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)"
)
@ -462,7 +452,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)")
emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)")
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)")
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
@ -471,12 +460,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::nonzero : (Tensor) -> (Tensor)")
emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])")
emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)")
emit("aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)")
emit("aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)")
emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)")
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)")
# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
@ -492,8 +475,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
@ -502,7 +483,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::isnan : (Tensor) -> (Tensor)")
emit("aten::all : (Tensor) -> (Tensor)")
emit("aten::all.bool : (bool[]) -> (bool)")
emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)")
emit("aten::any : (Tensor) -> (Tensor)")
emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)")
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
@ -510,7 +490,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)")
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)")
@ -552,9 +531,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::max : (Tensor) -> (Tensor)")
emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::amax : (Tensor, int[], bool) -> (Tensor)")
emit("aten::min : (Tensor) -> (Tensor)")
emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::amin : (Tensor, int[], bool) -> (Tensor)")
emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)
emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer = True)
emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) -> (Tensor)", has_canonicalizer=True)
@ -586,8 +562,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)")
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
# Functionalization ops
emit("aten::alias_copy : (Tensor) -> (Tensor)")
@ -608,8 +582,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::view_copy : (Tensor, int[]) -> (Tensor)")
emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)")
emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)")
emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)")
emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)")
emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)")
emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)")
emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)")
@ -670,7 +642,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
@ -726,7 +697,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)")
emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)")
emit("aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)")
emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)")
# ==========================================================================