mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Fix refine types crash
This commit fixes https://github.com/llvm/torch-mlir/issues/1599. Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>pull/1637/head
parent
4aad5ccf39
commit
da8fdc9f96
|
@ -646,7 +646,8 @@ static Type getPromotedResultTypeAssumingNonZeroRank(
|
||||||
|
|
||||||
void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType(
|
void TypeAnalysis::fillInDTypeGivenDTypeIntAndInputDType(
|
||||||
ValueKnowledge &knowledge, Value dtype, Type inputDType) {
|
ValueKnowledge &knowledge, Value dtype, Type inputDType) {
|
||||||
assert(isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
|
assert(!inputDType ||
|
||||||
|
isBuiltInType(inputDType) && "`inputDType` must be a builtin type");
|
||||||
int64_t dtypeInt;
|
int64_t dtypeInt;
|
||||||
if (dtype.getType().isa<Torch::NoneType>())
|
if (dtype.getType().isa<Torch::NoneType>())
|
||||||
knowledge.dtype = inputDType;
|
knowledge.dtype = inputDType;
|
||||||
|
@ -946,6 +947,12 @@ void TypeAnalysis::visitOperation(Operation *op,
|
||||||
}
|
}
|
||||||
if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
|
if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
|
||||||
Type defaultDtype = operands[0]->getValue().dtype;
|
Type defaultDtype = operands[0]->getValue().dtype;
|
||||||
|
if (!defaultDtype) {
|
||||||
|
incorporateKnowledge(
|
||||||
|
sumDimIntList.getResult(),
|
||||||
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext()));
|
||||||
|
return;
|
||||||
|
}
|
||||||
// If the input dtype is bool, the result type should be i64.
|
// If the input dtype is bool, the result type should be i64.
|
||||||
if (defaultDtype.isInteger(1))
|
if (defaultDtype.isInteger(1))
|
||||||
defaultDtype =
|
defaultDtype =
|
||||||
|
|
|
@ -272,3 +272,24 @@ func.func @forward(%arg0: !torch.bool, %arg1: !torch.tensor) {
|
||||||
%2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
|
%2 = torch.aten.mul.Tensor %1, %1 : !torch.vtensor, !torch.vtensor -> !torch.vtensor
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @torch.aten.zeros_like(
|
||||||
|
// CHECK-SAME: %[[arg:.*]]: !torch.vtensor) {
|
||||||
|
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
|
||||||
|
// CHECK: %[[ZEROS:.*]] = torch.aten.zeros_like %[[arg]], %[[INT6]], %[[INT0]], %[[CPU]], %[[FALSE]], %[[INT1]] : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor<*,f32>
|
||||||
|
// CHECK: return
|
||||||
|
func.func @torch.aten.zeros_like(%arg: !torch.vtensor) {
|
||||||
|
%int6 = torch.constant.int 6
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%int1 = torch.constant.int 1
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%cpu = torch.constant.device "cpu"
|
||||||
|
%2 = torch.aten.zeros_like %arg, %int6, %int0, %cpu, %false, %int1 : !torch.vtensor, !torch.int, !torch.int, !torch.Device, !torch.bool, !torch.int -> !torch.vtensor
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue