Remove checking for training specific parameters in EmbeddingBag lowering (#3782)

The Torch-to-Linalg pass fails for `EmbeddingBag` when the operator's
training-only properties are set to `true`. For instance, the operator's
`sparse` input/property is training-specific, and if its value is `true`,
the existing lowering bails out. However, we don't need to check
training-specific parameters and bail out of the legalization, because
these properties are irrelevant in eval/inference mode.

---------

Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>
pull/3798/head
Hanumanth04 2024-10-15 09:37:26 -04:00 committed by GitHub
parent 1e431c6a90
commit 895f490cf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 26 deletions

View File

@ -222,23 +222,9 @@ public:
Value weight = adaptor.getWeight();
Value indices = adaptor.getIndices();
Value offsets = adaptor.getOffsets();
Value scaleGradByFreq = op.getScaleGradByFreq();
Value mode = op.getMode();
Value sparse = op.getSparse();
Value includeLastOffset = op.getIncludeLastOffset();
bool scaleGradByFreqBool;
if (!matchPattern(scaleGradByFreq,
m_TorchConstantBool(&scaleGradByFreqBool))) {
return rewriter.notifyMatchFailure(
op, "scale_grad_by_freq is expected to be a constant boolean value.");
}
if (scaleGradByFreqBool) {
return rewriter.notifyMatchFailure(
op, "Unimplemented: scale_grad_by_freq=True.");
}
int64_t modeInt;
if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
return rewriter.notifyMatchFailure(
@ -251,18 +237,6 @@ public:
"not supported yet for EmbeddingBag.");
}
bool isSparse;
if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) {
return rewriter.notifyMatchFailure(
op, "sparse is expected to be a constant boolean value.");
}
if (isSparse) {
return rewriter.notifyMatchFailure(
op,
"Unimplemented: Sparse mode is not supported yet for EmbeddingBag.");
}
bool discardLastOffset;
if (!matchPattern(includeLastOffset,
m_TorchConstantBool(&discardLastOffset))) {

View File

@ -0,0 +1,52 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d1)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-LABEL: func.func @torchAtenEmbeddingBagPaddingIdx
// CHECK: %[[VAL_0:.*]]: !torch.vtensor<[1000000,64],f32>
// CHECK: %[[VAL_1:.*]]: !torch.vtensor<[204790],si64>
// CHECK: %[[VAL_2:.*]]: !torch.vtensor<[2048],si64>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2048],si64> -> tensor<2048xi64>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[204790],si64> -> tensor<204790xi64>
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1000000,64],f32> -> tensor<1000000x64xf32>
// CHECK-DAG: %[[VAL_6:.*]] = torch.constant.bool true
// CHECK-DAG: %[[VAL_7:.*]] = torch.constant.int 0
// CHECK-DAG: %[[VAL_8:.*]] = torch.constant.bool true
// Regression test for the EmbeddingBag lowering in -convert-torch-to-linalg.
// The op below is built with both `scale_grad_by_freq` and `sparse` set to
// true; the lowering used to reject either of those values, so this input
// exercises the path where they are no longer inspected.
//
// Inputs: a [1000000,64] f32 weight table, 204790 si64 indices and 2048 si64
// bag offsets. The op yields four results; their static shapes are spelled
// out in the result types below.
func.func @torchAtenEmbeddingBagPaddingIdx(%weight: !torch.vtensor<[1000000,64],f32>,
%indices: !torch.vtensor<[204790],si64>,
%offsets: !torch.vtensor<[2048],si64>) -> (!torch.vtensor<[2048,64],f32>,
!torch.vtensor<[0],si64>,
!torch.vtensor<[2048],si64>,
!torch.vtensor<[2048],si64>)
{
// Training-specific flag, deliberately true (previously caused a match failure).
%scale_grad_by_freq = torch.constant.bool true
// NOTE(review): mode 0 presumably selects the "sum" reduction, following the
// PyTorch embedding_bag convention; confirm against the lowering.
%mode = torch.constant.int 0
// Training-specific flag, deliberately true (previously caused a match failure).
%sparse = torch.constant.bool true
// No per-sample weights are supplied.
%per_sample_weights = torch.constant.none
%include_last_offset = torch.constant.bool false
// No padding index is supplied.
%padding_idx = torch.constant.none
// Operand order: weight, indices, offsets, scale_grad_by_freq, mode, sparse,
// per_sample_weights, include_last_offset, padding_idx.
%result0, %result1, %result2, %result3 = torch.aten.embedding_bag.padding_idx %weight,
%indices,
%offsets,
%scale_grad_by_freq,
%mode,
%sparse,
%per_sample_weights,
%include_last_offset,
%padding_idx :
!torch.vtensor<[1000000,64],f32>,
!torch.vtensor<[204790],si64>,
!torch.vtensor<[2048],si64>,
!torch.bool,
!torch.int,
!torch.bool,
!torch.none,
!torch.bool,
!torch.none -> !torch.vtensor<[2048,64],f32>,
!torch.vtensor<[0],si64>,
!torch.vtensor<[2048],si64>,
!torch.vtensor<[2048],si64>
return %result0, %result1, %result2, %result3 : !torch.vtensor<[2048,64],f32>, !torch.vtensor<[0],si64>, !torch.vtensor<[2048],si64>, !torch.vtensor<[2048],si64>
}