mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix refine types crash
This commit fixes https://github.com/llvm/torch-mlir/issues/1599. Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1637/head
parent
4aad5ccf39
commit
da8fdc9f96
|
@ -646,7 +646,8 @@ static Type getPromotedResultTypeAssumingNonZeroRank(
|
|||
|
||||
void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType(
|
||||
ValueKnowledge &knowledge, Value dtype, Type inputDType) {
|
||||
assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
|
||||
assert(!inputDType ||
|
||||
isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
|
||||
int64_t dtypeInt;
|
||||
if (dtype.getType().isa<Torch::NoneType>())
|
||||
knowledge.dtype = inputDType;
|
||||
|
@ -946,6 +947,12 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
}
|
||||
if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
|
||||
Type defaultDtype = operands[0]->getValue().dtype;
|
||||
if (!defaultDtype) {
|
||||
incorporateKnowledge(
|
||||
sumDimIntList.getResult(),
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext()));
|
||||
return;
|
||||
}
|
||||
// If the input dtype is bool, the result type should be i64.
|
||||
if (defaultDtype.isInteger(1))
|
||||
defaultDtype =
|
||||
|
|
|
@ -272,3 +272,24 @@ func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) {
|
|||
%2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.zeros_like(
|
||||
// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) {
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
|
||||
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor<*,f32>
|
||||
// CHECK: return
|
||||
func.func @torch.aten.zeros_like(%arg: !torch.vtensor) {
|
||||
%int6 = torch.constant.int 6
|
||||
%false = torch.constant.bool false
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%cpu = torch.constant.device "cpu"
|
||||
%2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue