diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 098d83b75..ae7cd832a 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -24,23 +24,6 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28 -static int64_t getDtypeIntegerFromMlirType(Type dtype) { - if (dtype.isa()) - return 6; - - if (dtype.isa()) - return 15; - - if (auto integerType = dtype.dyn_cast()) { - if (integerType.isSignedInteger(64)) - return 4; - if (integerType.isSignlessInteger(1)) - return 11; - } - return -1; -} - //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// @@ -1721,9 +1704,9 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef operands) { OpFoldResult PrimDtypeOp::fold(ArrayRef operands) { BaseTensorType tensorType = a().getType().cast(); if (tensorType.hasDtype()) { - int64_t dtypeInt = getDtypeIntegerFromMlirType(tensorType.getDtype()); - if (dtypeInt != -1) - return getI64IntegerAttr(getContext(), dtypeInt); + torch_upstream::ScalarType scalarType = + Torch::getScalarTypeForType(tensorType.getDtype()); + return getI64IntegerAttr(getContext(), static_cast(scalarType)); } return nullptr; }