From 895f490cf5bba9c85a056853da4309d3ea633857 Mon Sep 17 00:00:00 2001
From: Hanumanth04
Date: Tue, 15 Oct 2024 09:37:26 -0400
Subject: [PATCH] Remove checking for training-specific parameters in EmbeddingBag lowering (#3782)

The torch-to-linalg pass fails for `EmbeddingBag` when the operator's
training-specific properties are set to `true`. For instance, the
`sparse` input is training-specific, and if its value is `true`, the
existing lowering bails out. However, there is no need to check these
training-specific parameters and bail out of the legalization, since
they do not affect the result in eval/inference mode.
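As a rough PyTorch-level illustration (a hypothetical repro sketch, not
part of this patch): the forward computation of `embedding_bag` is
unchanged by `sparse` and `scale_grad_by_freq`, which only influence how
gradients are handled during training.

    # Hypothetical sketch, assuming a recent PyTorch release: the
    # training-only flags leave the forward result unchanged.
    import torch
    import torch.nn.functional as F

    weight = torch.randn(10, 3)
    indices = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
    offsets = torch.tensor([0, 4])

    plain = F.embedding_bag(indices, weight, offsets, mode="sum",
                            sparse=False, scale_grad_by_freq=False)
    flagged = F.embedding_bag(indices, weight, offsets, mode="sum",
                              sparse=True, scale_grad_by_freq=True)

    # Same forward output; the flags matter only for backward/training.
    assert torch.allclose(plain, flagged)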
---------

Co-authored-by: Hanumanth Hanumantharayappa
---
 .../TorchToLinalg/IndirectDataMovement.cpp    | 26 ----------
 .../TorchToLinalg/embeddingBag.mlir           | 52 +++++++++++++++++++
 2 files changed, 52 insertions(+), 26 deletions(-)
 create mode 100644 test/Conversion/TorchToLinalg/embeddingBag.mlir

diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
index fbc5004c9..07e4b23a1 100644
--- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp
@@ -222,23 +222,9 @@ public:
     Value weight = adaptor.getWeight();
     Value indices = adaptor.getIndices();
     Value offsets = adaptor.getOffsets();
-    Value scaleGradByFreq = op.getScaleGradByFreq();
     Value mode = op.getMode();
-    Value sparse = op.getSparse();
     Value includeLastOffset = op.getIncludeLastOffset();
 
-    bool scaleGradByFreqBool;
-    if (!matchPattern(scaleGradByFreq,
-                      m_TorchConstantBool(&scaleGradByFreqBool))) {
-      return rewriter.notifyMatchFailure(
-          op, "scale_grad_by_freq is expected to be a constant boolean value.");
-    }
-
-    if (scaleGradByFreqBool) {
-      return rewriter.notifyMatchFailure(
-          op, "Unimplemented: scale_grad_by_freq=True.");
-    }
-
     int64_t modeInt;
     if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
       return rewriter.notifyMatchFailure(
@@ -251,18 +237,6 @@ public:
                                        "not supported yet for EmbeddingBag.");
     }
 
-    bool isSparse;
-    if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) {
-      return rewriter.notifyMatchFailure(
-          op, "sparse is expected to be a constant boolean value.");
-    }
-
-    if (isSparse) {
-      return rewriter.notifyMatchFailure(
-          op,
-          "Unimplemented: Sparse mode is not supported yet for EmbeddingBag.");
-    }
-
     bool discardLastOffset;
     if (!matchPattern(includeLastOffset,
                       m_TorchConstantBool(&discardLastOffset))) {
diff --git a/test/Conversion/TorchToLinalg/embeddingBag.mlir b/test/Conversion/TorchToLinalg/embeddingBag.mlir
new file mode 100644
index 000000000..05aa57fc7
--- /dev/null
+++ b/test/Conversion/TorchToLinalg/embeddingBag.mlir
@@ -0,0 +1,52 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d1)>
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-LABEL: func.func @torchAtenEmbeddingBagPaddingIdx
+// CHECK: %[[VAL_0:.*]]: !torch.vtensor<[1000000,64],f32>
+// CHECK: %[[VAL_1:.*]]: !torch.vtensor<[204790],si64>
+// CHECK: %[[VAL_2:.*]]: !torch.vtensor<[2048],si64>
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2048],si64> -> tensor<2048xi64>
+// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[204790],si64> -> tensor<204790xi64>
+// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1000000,64],f32> -> tensor<1000000x64xf32>
+// CHECK-DAG: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK-DAG: %[[VAL_7:.*]] = torch.constant.int 0
+// CHECK-DAG: %[[VAL_8:.*]] = torch.constant.bool true
+func.func @torchAtenEmbeddingBagPaddingIdx(%weight: !torch.vtensor<[1000000,64],f32>,
+                                           %indices: !torch.vtensor<[204790],si64>,
+                                           %offsets: !torch.vtensor<[2048],si64>) -> (!torch.vtensor<[2048,64],f32>,
+                                                                                      !torch.vtensor<[0],si64>,
+                                                                                      !torch.vtensor<[2048],si64>,
+                                                                                      !torch.vtensor<[2048],si64>)
+{
+  %scale_grad_by_freq = torch.constant.bool true
+  %mode = torch.constant.int 0
+  %sparse = torch.constant.bool true
+  %per_sample_weights = torch.constant.none
+  %include_last_offset = torch.constant.bool false
+  %padding_idx = torch.constant.none
+  %result0, %result1, %result2, %result3 = torch.aten.embedding_bag.padding_idx %weight,
+      %indices,
+      %offsets,
+      %scale_grad_by_freq,
+      %mode,
+      %sparse,
+      %per_sample_weights,
+      %include_last_offset,
+      %padding_idx :
+      !torch.vtensor<[1000000,64],f32>,
+      !torch.vtensor<[204790],si64>,
+      !torch.vtensor<[2048],si64>,
+      !torch.bool,
+      !torch.int,
+      !torch.bool,
+      !torch.none,
+      !torch.bool,
+      !torch.none -> !torch.vtensor<[2048,64],f32>,
+      !torch.vtensor<[0],si64>,
+      !torch.vtensor<[2048],si64>,
+      !torch.vtensor<[2048],si64>
+
+  return %result0, %result1, %result2, %result3 : !torch.vtensor<[2048,64],f32>, !torch.vtensor<[0],si64>, !torch.vtensor<[2048],si64>, !torch.vtensor<[2048],si64>
+}