From ed9bd556b3f59feb0fe1cca9776a3c2eb82fe826 Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Fri, 11 Feb 2022 12:59:02 +0000 Subject: [PATCH] Fix bug for aten_nll_loss op in the refine types pass The check for `self.hasSizes` was missing before performing `.size()` operation. --- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 30 +++++++++----------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 566c64a9e..091112b2f 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -1177,26 +1177,24 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp( auto self = operands[0]->getValue(); auto outputKnowledge = ValueKnowledge::getNotNonePessimisticValueState(op.getContext()); - - // Contains Knowledge of shape and dtype for the 1st result. - outputKnowledge.dtype = self.dtype; - int64_t reduction; - unsigned resultRank = self.sizes.size(); - - // Contains Knowledge of shape and dtype for the 2nd result. auto totalWeightKnowledge = ValueKnowledge::getNotNonePessimisticValueState(op.getContext()); - totalWeightKnowledge.dtype = self.dtype; - totalWeightKnowledge.sizes.resize(0, kUnknownSize); - totalWeightKnowledge.hasSizes = true; - if (self.hasSizes && - matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) { - if (reduction != Reduction::None) - resultRank -= 1; + // `AtenNllLossForward` op returns two outputs, output and total_weight. + // The rank of the output depends on the reduction parameter and total_weight + // is a scalar value. + outputKnowledge.dtype = self.dtype; + totalWeightKnowledge.dtype = self.dtype; + totalWeightKnowledge.hasSizes = true; + if (self.hasSizes) { + int64_t reduction; + if (matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) { + outputKnowledge.hasSizes = true; + unsigned resultRank = self.sizes.size(); + if (reduction == Reduction::None) + outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize); + } } - outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize); - outputKnowledge.hasSizes = true; auto resultLattice = getLatticeElement(op.getResult(0)).join(outputKnowledge); resultLattice |= getLatticeElement(op.getResult(1)).join(totalWeightKnowledge);