mirror of https://github.com/llvm/torch-mlir
Fix bug for aten_nll_loss op in the refine types pass
The check for `self.hasSizes` was missing before performing `.size()` operation.pull/608/head
parent
f8cb32faf0
commit
ed9bd556b3
|
@ -1177,26 +1177,24 @@ ChangeResult TypeAnalyzer::visitAtenNllLossForwardOp(
|
|||
auto self = operands[0]->getValue();
|
||||
auto outputKnowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
|
||||
// Contains Knowledge of shape and dtype for the 1st result.
|
||||
outputKnowledge.dtype = self.dtype;
|
||||
int64_t reduction;
|
||||
unsigned resultRank = self.sizes.size();
|
||||
|
||||
// Contains Knowledge of shape and dtype for the 2nd result.
|
||||
auto totalWeightKnowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||
totalWeightKnowledge.dtype = self.dtype;
|
||||
totalWeightKnowledge.sizes.resize(0, kUnknownSize);
|
||||
totalWeightKnowledge.hasSizes = true;
|
||||
|
||||
if (self.hasSizes &&
|
||||
matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
|
||||
if (reduction != Reduction::None)
|
||||
resultRank -= 1;
|
||||
// `AtenNllLossForward` op returns two outputs, output and total_weight.
|
||||
// The rank of the output depends on the reduction parameter and total_weight
|
||||
// is a scalar value.
|
||||
outputKnowledge.dtype = self.dtype;
|
||||
totalWeightKnowledge.dtype = self.dtype;
|
||||
totalWeightKnowledge.hasSizes = true;
|
||||
if (self.hasSizes) {
|
||||
int64_t reduction;
|
||||
if (matchPattern(op.reduction(), m_TorchConstantInt(&reduction))) {
|
||||
outputKnowledge.hasSizes = true;
|
||||
unsigned resultRank = self.sizes.size();
|
||||
if (reduction == Reduction::None)
|
||||
outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
|
||||
}
|
||||
}
|
||||
outputKnowledge.sizes.resize(resultRank - 1, kUnknownSize);
|
||||
outputKnowledge.hasSizes = true;
|
||||
auto resultLattice = getLatticeElement(op.getResult(0)).join(outputKnowledge);
|
||||
resultLattice |=
|
||||
getLatticeElement(op.getResult(1)).join(totalWeightKnowledge);
|
||||
|
|
Loading…
Reference in New Issue