Fix bug for aten_nll_loss op in the refine types pass

The check for `self.hasSizes` was missing before performing the `.size()`
operation.
pull/608/head
Prashant Kumar 2022-02-11 12:59:02 +00:00
parent f8cb32faf0
commit ed9bd556b3
1 changed file with 14 additions and 16 deletions

@@ -1177,26 +1177,24 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
   auto self = operands[0]->getValue();
   auto outputKnowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
 
-  // Contains Knowledge of shape and dtype for the 1st result.
-  outputKnowledge.dtype = self.dtype;
-  int64_t reduction;
-  unsigned resultRank = self.sizes.size();
-
-  // Contains Knowledge of shape and dtype for the 2nd result.
+  // `AtenNllLossForward` op returns two outputs, output and total_weight.
+  // The rank of the output depends on the reduction parameter and total_weight
+  // is a scalar value.
   auto totalWeightKnowledge =
       ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
-  totalWeightKnowledge.dtype = self.dtype;
-  totalWeightKnowledge.sizes.resize(0, kUnknownSize);
-  totalWeightKnowledge.hasSizes = true;
-  if (self.hasSizes &&
-      matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
-    if (reduction != Reduction::None)
-      resultRank -= 1;
+  outputKnowledge.dtype = self.dtype;
+  totalWeightKnowledge.dtype = self.dtype;
+  totalWeightKnowledge.hasSizes = true;
+  if (self.hasSizes) {
+    int64_t reduction;
+    if (matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
+      outputKnowledge.hasSizes = true;
+      unsigned resultRank = self.sizes.size();
+      if (reduction == Reduction::None)
+        outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
+    }
   }
-  outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
-  outputKnowledge.hasSizes = true;
-
   auto resultLattice = getLatticeElement(op.getResult(0)).join(outputKnowledge);
   resultLattice |=
       getLatticeElement(op.getResult(1)).join(totalWeightKnowledge);
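
For readers outside the codebase: `self.sizes` is only meaningful when `self.hasSizes` is true (it stays empty for an unranked tensor), so the old unconditional `unsigned resultRank = self.sizes.size();` yielded 0, and the unsigned subtraction in the later `resize(resultRank - 1, kUnknownSize)` wrapped around to a huge allocation size. A minimal standalone sketch of that failure mode, using a hypothetical `Knowledge` struct rather than the real ValueKnowledge type:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for ValueKnowledge; only the fields that matter
// for this bug are modeled here.
struct Knowledge {
  bool hasSizes = false;       // false when the tensor is unranked
  std::vector<int64_t> sizes;  // only meaningful if hasSizes is true
};

int main() {
  Knowledge self;  // simulates an unranked input tensor

  // What the old code did, with no hasSizes guard:
  unsigned resultRank = self.sizes.size();  // 0
  // 0u - 1 wraps modulo 2^32 (for 32-bit unsigned), so the old
  // resize(resultRank - 1, ...) asked for ~4.29 billion elements
  // instead of a scalar shape.
  std::cout << "resultRank - 1 = " << resultRank - 1 << "\n";

  // What the fixed code does: touch sizes only under the guard.
  std::vector<int64_t> outputSizes;
  if (self.hasSizes) {
    unsigned rank = self.sizes.size();
    outputSizes.resize(rank - 1, /*kUnknownSize*/ -1);
  }
  std::cout << "output shape left unset, rank " << outputSizes.size() << "\n";
  return 0;
}

Note the fix also inverts the reduction check: instead of computing a rank up front and decrementing it, the new code leaves `outputKnowledge.sizes` empty (a scalar shape) by default and only resizes to `resultRank - 1` for `Reduction::None`, so the subtraction can never run on an unranked input.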