mirror of https://github.com/llvm/torch-mlir
Remove checking for training specific parameters in EmbeddingBag lowering (#3782)
Torch-to-linalg pass fails for `EmbeddingBag` when the training only specific properties of the operator are set to `true.` For instance, this operator's `sparse` input/property is training-specific, and if the value of this property is `true,` the existing lowering bails out. However, we don't need to check for training-specific parameters and bailout from the legalization since we don't care about these properties during the eval/inference mode. --------- Co-authored-by: Hanumanth Hanumantharayappa <hhanuman@ah-hhanuman-l.dhcp.mathworks.com>pull/3798/head
parent
1e431c6a90
commit
895f490cf5
|
@ -222,23 +222,9 @@ public:
|
|||
Value weight = adaptor.getWeight();
|
||||
Value indices = adaptor.getIndices();
|
||||
Value offsets = adaptor.getOffsets();
|
||||
Value scaleGradByFreq = op.getScaleGradByFreq();
|
||||
Value mode = op.getMode();
|
||||
Value sparse = op.getSparse();
|
||||
Value includeLastOffset = op.getIncludeLastOffset();
|
||||
|
||||
bool scaleGradByFreqBool;
|
||||
if (!matchPattern(scaleGradByFreq,
|
||||
m_TorchConstantBool(&scaleGradByFreqBool))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "scale_grad_by_freq is expected to be a constant boolean value.");
|
||||
}
|
||||
|
||||
if (scaleGradByFreqBool) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: scale_grad_by_freq=True.");
|
||||
}
|
||||
|
||||
int64_t modeInt;
|
||||
if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
|
@ -251,18 +237,6 @@ public:
|
|||
"not supported yet for EmbeddingBag.");
|
||||
}
|
||||
|
||||
bool isSparse;
|
||||
if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "sparse is expected to be a constant boolean value.");
|
||||
}
|
||||
|
||||
if (isSparse) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Unimplemented: Sparse mode is not supported yet for EmbeddingBag.");
|
||||
}
|
||||
|
||||
bool discardLastOffset;
|
||||
if (!matchPattern(includeLastOffset,
|
||||
m_TorchConstantBool(&discardLastOffset))) {
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d1)>
|
||||
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d0)>
|
||||
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
|
||||
// CHECK-LABEL: func.func @torchAtenEmbeddingBagPaddingIdx
|
||||
// CHECK: %[[VAL_0:.*]]: !torch.vtensor<[1000000,64],f32>
|
||||
// CHECK: %[[VAL_1:.*]]: !torch.vtensor<[204790],si64>
|
||||
// CHECK: %[[VAL_2:.*]]: !torch.vtensor<[2048],si64>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[2048],si64> -> tensor<2048xi64>
|
||||
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[204790],si64> -> tensor<204790xi64>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1000000,64],f32> -> tensor<1000000x64xf32>
|
||||
// CHECK-DAG: %[[VAL_6:.*]] = torch.constant.bool true
|
||||
// CHECK-DAG: %[[VAL_7:.*]] = torch.constant.int 0
|
||||
// CHECK-DAG: %[[VAL_8:.*]] = torch.constant.bool true
|
||||
func.func @torchAtenEmbeddingBagPaddingIdx(%weight: !torch.vtensor<[1000000,64],f32>,
|
||||
%indices: !torch.vtensor<[204790],si64>,
|
||||
%offsets: !torch.vtensor<[2048],si64>) -> (!torch.vtensor<[2048,64],f32>,
|
||||
!torch.vtensor<[0],si64>,
|
||||
!torch.vtensor<[2048],si64>,
|
||||
!torch.vtensor<[2048],si64>)
|
||||
{
|
||||
%scale_grad_by_freq = torch.constant.bool true
|
||||
%mode = torch.constant.int 0
|
||||
%sparse = torch.constant.bool true
|
||||
%per_sample_weights = torch.constant.none
|
||||
%include_last_offset = torch.constant.bool false
|
||||
%padding_idx = torch.constant.none
|
||||
%result0, %result1, %result2, %result3 = torch.aten.embedding_bag.padding_idx %weight,
|
||||
%indices,
|
||||
%offsets,
|
||||
%scale_grad_by_freq,
|
||||
%mode,
|
||||
%sparse,
|
||||
%per_sample_weights,
|
||||
%include_last_offset,
|
||||
%padding_idx :
|
||||
!torch.vtensor<[1000000,64],f32>,
|
||||
!torch.vtensor<[204790],si64>,
|
||||
!torch.vtensor<[2048],si64>,
|
||||
!torch.bool,
|
||||
!torch.int,
|
||||
!torch.bool,
|
||||
!torch.none,
|
||||
!torch.bool,
|
||||
!torch.none -> !torch.vtensor<[2048,64],f32>,
|
||||
!torch.vtensor<[0],si64>,
|
||||
!torch.vtensor<[2048],si64>,
|
||||
!torch.vtensor<[2048],si64>
|
||||
|
||||
return %result0, %result1, %result2, %result3 : !torch.vtensor<[2048,64],f32>, !torch.vtensor<[0],si64>, !torch.vtensor<[2048],si64>, !torch.vtensor<[2048],si64>
|
||||
}
|
Loading…
Reference in New Issue