diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 1520c2baf..b0d961ced 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -646,7 +646,8 @@ static Type getPromotedResultTypeAssumingNonZeroRank( void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType( ValueKnowledge &knowledge, Value dtype, Type inputDType) { - assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type"); + assert(!inputDType || + isBuiltInType(inputDType) && "`inputDType` must be a builtin type"); int64_t dtypeInt; if (dtype.getType().isa()) knowledge.dtype = inputDType; @@ -946,6 +947,12 @@ void TypeAnalysis::visitOperation(Operation *op, } if (auto sumDimIntList = dyn_cast(op)) { Type defaultDtype = operands[0]->getValue().dtype; + if (!defaultDtype) { + incorporateKnowledge( + sumDimIntList.getResult(), + ValueKnowledge::getTensorPessimisticValueState(op->getContext())); + return; + } // If the input dtype is bool, the result type should be i64. if (defaultDtype.isInteger(1)) defaultDtype = diff --git a/test/Dialect/Torch/refine-types.mlir b/test/Dialect/Torch/refine-types.mlir index b8ee124fb..52f30b1bc 100644 --- a/test/Dialect/Torch/refine-types.mlir +++ b/test/Dialect/Torch/refine-types.mlir @@ -272,3 +272,24 @@ func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) { %2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor return } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.zeros_like( +// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) { +// CHECK: %[[INT6:.*]] = torch.constant.int 6 +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[CPU:.*]] = torch.constant.device "cpu" +// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor<*,f32> +// CHECK: return +func.func @torch.aten.zeros_like(%arg: !torch.vtensor) { + %int6 = torch.constant.int 6 + %false = torch.constant.bool false + %int1 = torch.constant.int 1 + %int0 = torch.constant.int 0 + %cpu = torch.constant.device "cpu" + %2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor + return +}