From 3dd29f9d5d1d0a722f70d5a20db05c6825e6d269 Mon Sep 17 00:00:00 2001
From: Gleb Kazantaev
Date: Mon, 21 Aug 2023 06:36:39 -0400
Subject: [PATCH] Update Torch ODS list with new ops (#2361)

* [LTC] Add shape_inference_(add|uniform)
* Add torch.multinomial op.
* Update ods gen; add normal_functional and erfinv ops support
* New TorchMLIR ops: clamp_min.Tensor, clamp_max.Tensor, xlogy, binary_cross_entropy, log_sigmoid_forward, sigmoid_backward, cosine_embedding_loss, scatter.reduce
* Improve the shape inference logic of whereOp
  - Infer the result tensor according to the broadcasting semantics

Signed-off-by: rahul shrivastava

* Added aten::sgn
* Add shape inference logic for hardtanh_backward op
* Added new Torch-MLIR ops

Co-authored-by: GlebKazantaev

* Add support for elu lowering
* Add support for elu_backward lowering
* Support fmod, remainder, and floor_divide
  Emit generated op defs for the remainder.Tensor and fmod.Tensor
  Add shape inference implementations for remainder.Scalar, fmod.Scalar and floor_divide.Tensor
* Add shape inference logic for im2col
  - pytorch.nn.unfold gets decomposed into im2col

Signed-off-by: rahul shrivastava

* Add aten::eye and aten::eye.m support
* Add tracing for linalg_qr
* Update GeneratedTorchOps.td
* Update xfails
* Fix unbound variable issue in torch_ods_gen

---------

Signed-off-by: rahul shrivastava
Co-authored-by: Mark Browning
Co-authored-by: zihaoc-cerebras
Co-authored-by: rahul shrivastava
Co-authored-by: Gokul Ramakrishnan
Co-authored-by: glebk-cerebras <111300564+glebk-cerebras@users.noreply.github.com>
Co-authored-by: Behzad Abghari
Co-authored-by: Ahmed Elkoushy
---
 build_tools/update_torch_ods.sh | 3 +-
 e2e_testing/xfail_sets.py | 27 -
 .../Dialect/Torch/IR/GeneratedTorchOps.td | 826 ++++++++++++++++++
 .../base_lazy_backend/shape_inference.cpp | 250 ++++--
 .../jit_ir/build_tools/torch_ods_gen.py | 36 +-
 5 files changed, 1056 insertions(+), 86 deletions(-)

diff --git a/build_tools/update_torch_ods.sh b/build_tools/update_torch_ods.sh
index 6bc4b7109..e0564a62d 100755
--- a/build_tools/update_torch_ods.sh
+++ b/build_tools/update_torch_ods.sh
@@ -41,7 +41,8 @@ if [ !
-z ${TORCH_MLIR_EXT_MODULES} ]; then ext_module="${TORCH_MLIR_EXT_MODULES}" fi -PYTHONPATH="${pypath}" python \ +set +u +PYTHONPATH="${PYTHONPATH}:${pypath}" python \ -m torch_mlir.dialects.torch.importer.jit_ir.build_tools.torch_ods_gen \ --torch_ir_include_dir="${torch_ir_include_dir}" \ --pytorch_op_extensions="${ext_module}" \ diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index ff9da8d80..8224a6632 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -1278,8 +1278,6 @@ LTC_XFAIL_SET = { "BoolIntTrueModule_basic", "CeilFloatModule_basic", "DivFloatModule_basic", - "ElementwiseAtenFloorDivideBroadcastModule_basic", - "ElementwiseAtenFloorDivideModule_basic", "EqIntModule_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", @@ -1287,7 +1285,6 @@ LTC_XFAIL_SET = { "GtFloatIntModule_basic", "GtIntModule_basic", "HBC_basic", - "HardtanhBackward_basic", "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DFloatNonAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic", @@ -1351,8 +1348,6 @@ LTC_XFAIL_SET = { "NeFloatIntModule_basic", "NeIntModule_basic", "QuantizedMLP_basic", - "RandLikeDtypeModule_basic", - "RandLikeModule_basic", "RollModule_basic", "ScalarImplicitFloatModule_basic", "ScalarImplicitIntModule_basic", @@ -1371,20 +1366,14 @@ LTC_XFAIL_SET = { "TensorToIntZeroRank_basic", "TensorToInt_basic", "UniformModule_basic", - "UniformNoCorrelationModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "AtenEmbeddingBagSumExample_basic", "Aten_EmbeddingBagExample_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", - "ElementwiseRemainderScalarModule_Float_basic", - "ElementwiseRemainderScalarModule_Int_basic", "ElementwiseRemainderScalarModule_Bool_basic", "AtenIntTensorByteDtypeModule_basic", "AtenIntTensorCharDtypeModule_basic", - "Fill_TensorFloat32WithFloat32_basic", - "Fill_TensorFloat32WithFloat64_basic", - "Fill_TensorFloat32WithInt64_basic", "UpSampleNearest2dBackwardVec_basic", "UpSampleNearest2dBackwardOutputSizeNone_basic", "ConvolutionBackwardModule2D_basic", @@ -1406,24 +1395,8 @@ LTC_XFAIL_SET = { "NativeDropoutTrainStaticShapeModule_basic", "StdCorrectionKeepDimModule_basic", "StdCorrectionNoneModule_basic", - "VarBiasedModule_basic", - "VarCorrectionAllDimReduceModule_basic", - "VarCorrectionEmptyDimModule_basic", "VarCorrectionKeepDimModule_basic", - "VarCorrectionLargeInputModule_basic", - "VarCorrectionModule_basic", "VarCorrectionNoneModule_basic", - "VarCorrectionSingleDimReduceModule_basic", - "VarDimAllDimReduceModule_basic", - "VarDimBiasedModule_basic", - "VarDimEmptyDimModule_basic", - "VarDimModule_basic", - "VarDimMultiDimModule_basic", - "VarDimNegativeModule_basic", - "VarDimNoneDimModule_basic", - "VarDimSingleDimModule_basic", - "VarDimUnbiasedModule_basic", - "VarUnbiasedModule_basic", "AtenFloatScalarModule_basic", "PrimsSqueezeModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 7d362697e..b124cef10 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -113,6 +113,57 @@ def Torch_AtenHardtanh_Op : Torch_Op<"aten.hardtanh_", [ }]; } +def Torch_AtenEluOp : Torch_Op<"aten.elu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::elu : (Tensor, Scalar, Scalar, 
Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$alpha, + AnyTorchScalarType:$scale, + AnyTorchScalarType:$input_scale + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenEluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenElu_Op : Torch_Op<"aten.elu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::elu_ : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$alpha, + AnyTorchScalarType:$scale, + AnyTorchScalarType:$input_scale + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenElu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenElu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenReluOp : Torch_Op<"aten.relu", [ AllowsTypeRefinement, HasValueSemantics, @@ -385,6 +436,51 @@ def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [ }]; } +def Torch_AtenSgnOp : Torch_Op<"aten.sgn", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sgn : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSgnOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSgnOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenSgn_Op : Torch_Op<"aten.sgn_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::sgn_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSgn_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSgn_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenHardsigmoidOp : Torch_Op<"aten.hardsigmoid", [ AllowsTypeRefinement, HasValueSemantics, @@ -520,6 +616,51 @@ def Torch_AtenErf_Op : Torch_Op<"aten.erf_", [ }]; } +def Torch_AtenErfinvOp : Torch_Op<"aten.erfinv", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::erfinv : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenErfinvOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenErfinvOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenErfinv_Op : Torch_Op<"aten.erfinv_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { 
+ let summary = "Generated op for `aten::erfinv_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenErfinv_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenErfinv_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSiluOp : Torch_Op<"aten.silu", [ AllowsTypeRefinement, HasValueSemantics, @@ -2200,6 +2341,53 @@ def Torch_AtenClampMin_Op : Torch_Op<"aten.clamp_min_", [ }]; } +def Torch_AtenClampMinTensorOp : Torch_Op<"aten.clamp_min.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$min + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenClampMinTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenClampMinTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenClampMin_TensorOp : Torch_Op<"aten.clamp_min_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::clamp_min_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$min + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenClampMin_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenClampMin_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenClampMaxOp : Torch_Op<"aten.clamp_max", [ AllowsTypeRefinement, HasValueSemantics, @@ -2247,6 +2435,53 @@ def Torch_AtenClampMax_Op : Torch_Op<"aten.clamp_max_", [ }]; } +def Torch_AtenClampMaxTensorOp : Torch_Op<"aten.clamp_max.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$max + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenClampMaxTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenClampMaxTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenClampMax_TensorOp : Torch_Op<"aten.clamp_max_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::clamp_max_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$max + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenClampMax_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void 
AtenClampMax_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenLog2Op : Torch_Op<"aten.log2", [ AllowsTypeRefinement, HasValueSemantics, @@ -3456,6 +3691,30 @@ def Torch_AtenMishOp : Torch_Op<"aten.mish", [ }]; } +def Torch_AtenXlogyTensorOp : Torch_Op<"aten.xlogy.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenXlogyTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenXlogyTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenRsubScalarOp : Torch_Op<"aten.rsub.Scalar", [ AllowsTypeRefinement, HasValueSemantics, @@ -3967,6 +4226,32 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [ }]; } +def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$num_samples, + Torch_BoolType:$replacement, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMultinomialOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenMultinomialOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [ AllowsTypeRefinement, HasValueSemantics, @@ -5000,6 +5285,32 @@ def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [ }]; } +def Torch_AtenNormalFunctionalOp : Torch_Op<"aten.normal_functional", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::normal_functional : (Tensor, float, float, Generator?) 
-> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$mean, + Torch_FloatType:$std, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNormalFunctionalOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenNormalFunctionalOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [ AllowsTypeRefinement, HasValueSemantics, @@ -6140,6 +6451,31 @@ def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [ }]; } +def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$A, + Torch_StringType:$mode + ); + let results = (outs + AnyTorchTensorType:$Q, + AnyTorchTensorType:$R + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgQrOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 2); + } + void AtenLinalgQrOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 2); + } + }]; +} + def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [ AllowsTypeRefinement, HasValueSemantics, @@ -6342,6 +6678,159 @@ def Torch_AtenNonzeroStaticOp : Torch_Op<"aten.nonzero_static", [ }]; } +def Torch_AtenBinaryCrossEntropyOp : Torch_Op<"aten.binary_cross_entropy", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBinaryCrossEntropyOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenBinaryCrossEntropyOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenBinaryCrossEntropyBackwardOp : Torch_Op<"aten.binary_cross_entropy_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchTensorType:$target, + AnyTorchOptionalTensorType:$weight, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenBinaryCrossEntropyBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenBinaryCrossEntropyBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenLogSigmoidForwardOp : Torch_Op<"aten.log_sigmoid_forward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::log_sigmoid_forward : (Tensor) 
-> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$output, + AnyTorchTensorType:$buffer + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLogSigmoidForwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 2); + } + void AtenLogSigmoidForwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 2); + } + }]; +} + +def Torch_AtenLogSigmoidBackwardOp : Torch_Op<"aten.log_sigmoid_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$self, + AnyTorchTensorType:$buffer + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLogSigmoidBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenLogSigmoidBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenSigmoidBackwardOp : Torch_Op<"aten.sigmoid_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchTensorType:$output + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSigmoidBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenSigmoidBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenCosineEmbeddingLossOp : Torch_Op<"aten.cosine_embedding_loss", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input1, + AnyTorchTensorType:$input2, + AnyTorchTensorType:$target, + Torch_FloatType:$margin, + Torch_IntType:$reduction + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCosineEmbeddingLossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenCosineEmbeddingLossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ AllowsTypeRefinement, HasValueSemantics, @@ -6669,6 +7158,61 @@ def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [ }]; } +def Torch_AtenEyeOp : Torch_Op<"aten.eye", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eye : (int, int?, int?, Device?, bool?) 
-> (Tensor)`"; + let arguments = (ins + Torch_IntType:$n, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEyeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenEyeOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenEyeMOp : Torch_Op<"aten.eye.m", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)`"; + let arguments = (ins + Torch_IntType:$n, + Torch_IntType:$m, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEyeMOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenEyeMOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -6866,6 +7410,31 @@ def Torch_AtenAllBoolOp : Torch_Op<"aten.all.bool", [ }]; } +def Torch_AtenAllDimOp : Torch_Op<"aten.all.dim", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::all.dim : (Tensor, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAllDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenAllDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenAnyOp : Torch_Op<"aten.any", [ AllowsTypeRefinement, HasValueSemantics, @@ -7047,6 +7616,31 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [ }]; } +def Torch_AtenArgminOp : Torch_Op<"aten.argmin", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::argmin : (Tensor, int?, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalIntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenArgminOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenArgminOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [ AllowsTypeRefinement, HasValueSemantics, @@ -8137,6 +8731,80 @@ def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [ }]; } +def Torch_AtenMinOp : Torch_Op<"aten.min", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::min : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let 
hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenMinOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenMinDimOp : Torch_Op<"aten.min.dim", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$values, + AnyTorchTensorType:$indices + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMinDimOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 2); + } + void AtenMinDimOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 2); + } + }]; +} + +def Torch_AtenAminOp : Torch_Op<"aten.amin", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::amin : (Tensor, int[], bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$dim, + Torch_BoolType:$keepdim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAminOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenAminOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [ AllowsTypeRefinement, ReadOnly @@ -9026,6 +9694,58 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [ }]; } +def Torch_AtenFmodTensorOp : Torch_Op<"aten.fmod.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenFmodTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenFmodTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenUniqueConsecutiveOp : Torch_Op<"aten.unique_consecutive", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::unique_consecutive : (Tensor, bool, bool, int?) 
-> (Tensor, Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_BoolType:$return_inverse, + Torch_BoolType:$return_counts, + AnyTorchOptionalIntType:$dim + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1, + AnyTorchTensorType:$result2 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenUniqueConsecutiveOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 3); + } + void AtenUniqueConsecutiveOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 3); + } + }]; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, @@ -9466,6 +10186,60 @@ def Torch_AtenUnfoldCopyOp : Torch_Op<"aten.unfold_copy", [ }]; } +def Torch_AtenIm2colOp : Torch_Op<"aten.im2col", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$dilation, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$stride + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIm2colOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenIm2colOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + +def Torch_AtenScatterReduceOp : Torch_Op<"aten.scatter.reduce", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$dim, + AnyTorchTensorType:$index, + AnyTorchTensorType:$src, + Torch_StringType:$reduce + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenScatterReduceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenScatterReduceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenSelectScatterOp : Torch_Op<"aten.select_scatter", [ AllowsTypeRefinement, HasValueSemantics, @@ -10681,6 +11455,30 @@ def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [ }]; } +def Torch_AtenRemainderTensorOp : Torch_Op<"aten.remainder.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRemainderTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenRemainderTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [ AllowsTypeRefinement, HasValueSemantics, @@ -11938,6 +12736,34 @@ def Torch_AtenNativeDropoutBackwardOp : Torch_Op<"aten.native_dropout_backward", }]; } +def 
Torch_AtenEluBackwardOp : Torch_Op<"aten.elu_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$grad_output, + AnyTorchScalarType:$alpha, + AnyTorchScalarType:$scale, + AnyTorchScalarType:$input_scale, + Torch_BoolType:$is_result, + AnyTorchTensorType:$self_or_result + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEluBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenEluBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp index 97d35cdcd..a55e8cd52 100644 --- a/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp +++ b/python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include +#include #include #include @@ -17,47 +18,85 @@ namespace torch { namespace lazy { -// TODO(henrytu): Upstream these shape inference functions to PyTorch in the future. +// TODO(henrytu): Upstream these shape inference functions to PyTorch in the +// future. -std::vector -compute_shape_div(const at::Tensor& self, const at::Scalar& other) { +std::vector compute_shape_add(const at::Tensor& self, + const at::Scalar& other, + const at::Scalar& alpha) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector -compute_shape_mse_loss_backward( - const at::Tensor& grad_output, - const at::Tensor& self, - const at::Tensor& target, - int64_t reduction) { +std::vector compute_shape_sub(const at::Tensor& self, + const at::Scalar& other, + const at::Scalar& alpha) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector -compute_shape_mul(const at::Tensor& self, const at::Scalar& other) { +std::vector compute_shape_div(const at::Tensor& self, + const at::Scalar& other) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_mse_loss_backward( + const at::Tensor& grad_output, const at::Tensor& self, + const at::Tensor& target, int64_t reduction) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_mul(const at::Tensor& self, + const at::Scalar& other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_var( const at::Tensor& self, at::OptionalIntArrayRef dim, - c10::optional correction, bool keepdim) { + const c10::optional & correction, bool keepdim) { // Result of variance is scalar tensor. 
return {Shape(self.scalar_type(), {})}; } std::vector compute_shape_hardtanh( - const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val -) { + const at::Tensor& self, const at::Scalar& min_val, + const at::Scalar& max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_where( - const at::Tensor & condition, - const at::Tensor & self, - const at::Tensor & other) { +std::vector compute_shape_hardtanh_backward( + const at::Tensor& grad_output, const at::Tensor& self, + const at::Scalar& min_val, const at::Scalar& max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector compute_shape_where(const at::Tensor& condition, + const at::Tensor& self, + const at::Tensor& other) { + // There are cases like - + // torch.aten.where.self %42, %arg17, %37 : !torch.vtensor<[15,10],i1>, + // !torch.vtensor<[],f32>, !torch.vtensor<[15,10],f32>. + // So the result tensor would be the broadcast of all three operands. + auto condition_meta = at::native::empty_strided_meta_symint( + condition.sym_sizes(), condition.sym_strides(), + /*dtype=*/c10::make_optional(condition.scalar_type()), + /*layout=*/c10::make_optional(condition.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto self_meta = at::native::empty_strided_meta_symint( + self.sym_sizes(), self.sym_strides(), + /*dtype=*/c10::make_optional(self.scalar_type()), + /*layout=*/c10::make_optional(self.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto other_meta = at::native::empty_strided_meta_symint( + other.sym_sizes(), other.sym_strides(), + /*dtype=*/c10::make_optional(other.scalar_type()), + /*layout=*/c10::make_optional(other.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + auto out_meta = at::where(condition_meta, self_meta, other_meta); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + std::vector compute_shape_bucketize( const at::Tensor& self, const at::Tensor& boundaries, bool out_int32, bool right) { @@ -65,50 +104,64 @@ std::vector compute_shape_bucketize( return {Shape(dtype, self.sizes().vec())}; } -std::vector compute_shape_copy( - const at::Tensor& self, - const at::Tensor& src, - bool non_blocking) { +std::vector compute_shape_copy(const at::Tensor& self, + const at::Tensor& src, + bool non_blocking) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_floor_divide( + const at::Tensor& self, const at::Tensor& other) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_fmod(const at::Tensor& self, + const at::Scalar& other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_native_group_norm( - const at::Tensor& input, - const c10::optional& weight, - const c10::optional& bias, - int64_t N, int64_t C, int64_t HxW, - int64_t group, double eps) { + const at::Tensor& input, const c10::optional& weight, + const c10::optional& bias, int64_t N, int64_t C, int64_t HxW, + int64_t group, double eps) { - TORCH_CHECK( - input.sizes().size() >= 2, - "Input tensor must have at least batch and channel dimensions!"); + TORCH_CHECK(input.sizes().size() >= 2, + "Input tensor must have at least batch and channel dimensions!"); std::vector shapes; shapes.reserve(3); shapes.emplace_back(input.scalar_type(), input.sizes().vec()); // A separate mean and var needs to be kept for each 
group per N. - shapes.emplace_back( - at::get_default_dtype_as_scalartype(), - std::vector{N, group}); + shapes.emplace_back(at::get_default_dtype_as_scalartype(), + std::vector{N, group}); - shapes.emplace_back( - at::get_default_dtype_as_scalartype(), - std::vector{N, group}); + shapes.emplace_back(at::get_default_dtype_as_scalartype(), + std::vector{N, group}); return shapes; } -std::vector compute_shape_native_group_norm_backward( - const at::Tensor& grad_out, - const at::Tensor& input, - const at::Tensor& mean, - const at::Tensor& rstd, - const c10::optional& weight, - int64_t N, int64_t C, int64_t HxW, - int64_t group, ::std::array output_mask) { +std::vector compute_shape_im2col( + const at::Tensor& self, at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { - TORCH_CHECK( - input.sizes().size() >= 2, - "Input tensor must have at least batch and channel dimensions!"); + auto self_meta = at::native::empty_strided_meta_symint( + self.sym_sizes(), self.sym_strides(), + /*dtype=*/c10::make_optional(self.scalar_type()), + /*layout=*/c10::make_optional(self.layout()), + /*device=*/c10::make_optional(c10::Device(c10::kMeta)), + /*pin_memory=*/c10::nullopt); + + auto out_meta = at::im2col(self_meta, kernel_size, dilation, padding, stride); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_native_group_norm_backward( + const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean, + const at::Tensor& rstd, const c10::optional& weight, int64_t N, + int64_t C, int64_t HxW, int64_t group, ::std::array output_mask) { + + TORCH_CHECK(input.sizes().size() >= 2, + "Input tensor must have at least batch and channel dimensions!"); std::vector shapes; shapes.reserve(3); shapes.emplace_back(input.scalar_type(), input.sizes().vec()); @@ -116,15 +169,102 @@ std::vector compute_shape_native_group_norm_backward( int64_t num_features = input.size(1); // `weight` and `bias` are vectors of length C (number of channels)` - shapes.emplace_back( - at::get_default_dtype_as_scalartype(), - std::vector{num_features}); - shapes.emplace_back( - at::get_default_dtype_as_scalartype(), - std::vector{num_features}); + shapes.emplace_back(at::get_default_dtype_as_scalartype(), + std::vector{num_features}); + shapes.emplace_back(at::get_default_dtype_as_scalartype(), + std::vector{num_features}); return shapes; } +std::vector compute_shape_remainder( + const at::Tensor& self, const at::Scalar& other) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} -} // namespace lazy -} // namespace torch +std::vector compute_shape_uniform( + const at::Tensor& self, double from, double to, + c10::optional generator) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_normal_functional( + const at::Tensor& self, double mean, double std, + c10::optional generator) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_multinomial( + const at::Tensor& self, int64_t num_samples, bool replacement, + c10::optional generator) { + // Input tensor can be either 1D or 2D. The last dim of output + // should be 'num_samples'. So the output shape can be either + // [num_samples] or [m, num_samples]. + // Output type can only be long tensor. 
+ auto ishape = self.sizes().vec(); + ishape.back() = num_samples; + return {Shape(at::kLong, ishape)}; +} + +std::vector compute_shape_eye( + int64_t n, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + auto out_meta = + at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_eye( + int64_t n, int64_t m, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + auto out_meta = + at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory); + return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; +} + +std::vector compute_shape_full( + at::IntArrayRef size, const at::Scalar& fill_value, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_fill(const at::Tensor& self, + const at::Tensor& value) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +std::vector compute_shape_randn( + at::IntArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_randint( + int64_t high, at::IntArrayRef size, c10::optional dtype, + c10::optional layout, c10::optional device, + c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_randint( + int64_t low, int64_t high, at::IntArrayRef size, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + return { + Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; +} + +std::vector compute_shape_bernoulli( + const at::Tensor& self, const at::Tensor &p, + c10::optional generator) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + +} // namespace lazy +} // namespace torch \ No newline at end of file diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index f010c92f0..08671dd90 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -242,15 +242,18 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): for key in [ "aten::tanh : (Tensor) -> (Tensor)", "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)", + "aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)", "aten::relu : (Tensor) -> (Tensor)", "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", "aten::log : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", "aten::sign : (Tensor) -> (Tensor)", + "aten::sgn : (Tensor) -> (Tensor)", "aten::hardsigmoid : (Tensor) -> (Tensor)", "aten::hardswish : (Tensor) -> (Tensor)", "aten::erf : (Tensor) -> (Tensor)", + "aten::erfinv : (Tensor) -> (Tensor)", "aten::silu : (Tensor) -> (Tensor)", "aten::sin : (Tensor) -> (Tensor)", "aten::exp : (Tensor) -> (Tensor)", @@ -287,7 +290,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", "aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) 
-> (Tensor)", "aten::clamp_min : (Tensor, Scalar) -> (Tensor)", + "aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::clamp_max : (Tensor, Scalar) -> (Tensor)", + "aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::log2 : (Tensor) -> (Tensor)", "aten::sqrt : (Tensor) -> (Tensor)", "aten::log1p : (Tensor) -> (Tensor)", @@ -309,17 +314,18 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): # variants. emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) - + emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::maximum : (Tensor, Tensor) -> (Tensor)") emit("aten::minimum : (Tensor, Tensor) -> (Tensor)") emit("aten::mish : (Tensor) -> (Tensor)") + emit("aten::xlogy.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::rsub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit("aten::gelu : (Tensor, str) -> (Tensor)") emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)") @@ -344,6 +350,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)") emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)") emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)") + emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)") emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)") emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)") @@ -393,6 +400,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit( "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)" ) + emit( + "aten::normal_functional : (Tensor, float, float, Generator?) -> (Tensor)", + ) emit( "aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)" ) @@ -452,6 +462,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) 
-> (Tensor)") + emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)") emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)") @@ -460,6 +471,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::nonzero : (Tensor) -> (Tensor)") emit("aten::nonzero_numpy : (Tensor) -> (Tensor[])") emit("aten::nonzero_static : (Tensor, int, int) -> (Tensor)") + emit("aten::binary_cross_entropy : (Tensor, Tensor, Tensor?, int) -> (Tensor)") + emit("aten::binary_cross_entropy_backward : (Tensor, Tensor, Tensor, Tensor?, int) -> (Tensor)") + emit("aten::log_sigmoid_forward : (Tensor) -> (Tensor, Tensor)") + emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)") + emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)") + emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)") # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") @@ -475,6 +492,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)") emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)") emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)") @@ -483,6 +502,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::isnan : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)") emit("aten::all.bool : (bool[]) -> (bool)") + emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)") emit("aten::any : (Tensor) -> (Tensor)") emit("aten::any.dim : (Tensor, int, bool) -> (Tensor)") emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") @@ -490,6 +510,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)") emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") + emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)") emit("aten::one_hot : (Tensor, int) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") emit("aten::clone : (Tensor, int?) -> (Tensor)") @@ -531,6 +552,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::max : (Tensor) -> (Tensor)") emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)") emit("aten::amax : (Tensor, int[], bool) -> (Tensor)") + emit("aten::min : (Tensor) -> (Tensor)") + emit("aten::min.dim : (Tensor, int, bool) -> (Tensor, Tensor)") + emit("aten::amin : (Tensor, int[], bool) -> (Tensor)") emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True) emit("aten::to.dtype_layout : (Tensor, int?, int?, Device?, bool?, bool, bool, int?) -> (Tensor)", has_folder=True, has_canonicalizer = True) emit("aten::to.other : (Tensor, Tensor, bool, bool, int?) 
-> (Tensor)", has_canonicalizer=True) @@ -562,6 +586,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)") emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") + emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") + emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)") # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") @@ -582,6 +608,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::view_copy : (Tensor, int[]) -> (Tensor)") emit("aten::view_copy.dtype : (Tensor, int) -> (Tensor)") emit("aten::unfold_copy : (Tensor, int, int, int) -> (Tensor)") + emit("aten::im2col : (Tensor, int[], int[], int[], int[]) -> (Tensor)") + emit("aten::scatter.reduce : (Tensor, int, Tensor, Tensor, str) -> (Tensor)") emit("aten::select_scatter : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::slice_scatter : (Tensor, Tensor, int, int?, int?, int) -> (Tensor)") emit("aten::diagonal_scatter : (Tensor, Tensor, int, int, int) -> (Tensor)") @@ -642,6 +670,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True) emit("aten::remainder.int : (int, int) -> (int)", has_folder=True) emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)") + emit("aten::remainder.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True) emit("aten::mul.int : (int, int) -> (int)", has_folder=True) @@ -697,6 +726,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::native_batch_norm_backward : (Tensor, Tensor, Tensor?, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, bool[]) -> (Tensor, Tensor, Tensor)") emit("aten::native_group_norm_backward : (Tensor, Tensor, Tensor, Tensor, Tensor?, int, int, int, int, bool[]) -> (Tensor, Tensor, Tensor)") emit("aten::native_dropout_backward : (Tensor, Tensor, float) -> (Tensor)") + emit("aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)") emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") # ==========================================================================