[MLIR][TORCH] Fix refine types crash

This commit fixes https://github.com/llvm/torch-mlir/issues/1599.

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1637/head
Vivek Khandelwal 2022-11-22 17:45:14 +05:30
parent 4aad5ccf39
commit da8fdc9f96
2 changed files with 29 additions and 1 deletions

View File

@ -646,7 +646,8 @@ static Type getPromotedResultTypeAssumingNonZeroRank(
void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType( void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType(
ValueKnowledge &knowledge, Value dtype, Type inputDType) { ValueKnowledge &knowledge, Value dtype, Type inputDType) {
assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type"); assert(!inputDType ||
isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
int64_t dtypeInt; int64_t dtypeInt;
if (dtype.getType().isa<Torch::NoneType>()) if (dtype.getType().isa<Torch::NoneType>())
knowledge.dtype = inputDType; knowledge.dtype = inputDType;
@ -946,6 +947,12 @@ void TypeAnalysis::visitOperation(Operation *op,
} }
if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) { if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
Type defaultDtype = operands[0]->getValue().dtype; Type defaultDtype = operands[0]->getValue().dtype;
if (!defaultDtype) {
incorporateKnowledge(
sumDimIntList.getResult(),
ValueKnowledge::getTensorPessimisticValueState(op->getContext()));
return;
}
// If the input dtype is bool, the result type should be i64. // If the input dtype is bool, the result type should be i64.
if (defaultDtype.isInteger(1)) if (defaultDtype.isInteger(1))
defaultDtype = defaultDtype =

View File

@ -272,3 +272,24 @@ func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) {
%2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor %2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
return return
} }
// -----
// CHECK-LABEL: func.func @torch.aten.zeros_like(
// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) {
// CHECK: %[[INT6:.*]] = torch.constant.int 6
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor<*,f32>
// CHECK: return
func.func @torch.aten.zeros_like(%arg: !torch.vtensor) {
%int6 = torch.constant.int 6
%false = torch.constant.bool false
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%cpu = torch.constant.device "cpu"
%2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor
return
}