From 6491c69539632a52a48171297363e53ce4746f97 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Thu, 7 Jul 2022 14:21:05 -0700 Subject: [PATCH] torch: use ScalarType enum instead of raw constants (#1020) This patch replaces the use of raw integers like 6, 4, etc. (that represent PyTorch's scalar types) with named values from the ScalarType enum (e.g. `ScalarType::Float`, `ScalarType::Long`, etc.) in code for folding `prim.dtype` ops into numeric constants. This patch isn't strictly a non-functional change, since its use of `Torch::getScalarTypeForType()` implies that the input type has to be one among the supported types, otherwise compilation will abort, whereas previously, compilation proceeded without folding the unsupported data type into a numeric constant. --- lib/Dialect/Torch/IR/TorchOps.cpp | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 098d83b75..ae7cd832a 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -24,23 +24,6 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28 -static int64_t getDtypeIntegerFromMlirType(Type dtype) { - if (dtype.isa()) - return 6; - - if (dtype.isa()) - return 15; - - if (auto integerType = dtype.dyn_cast()) { - if (integerType.isSignedInteger(64)) - return 4; - if (integerType.isSignlessInteger(1)) - return 11; - } - return -1; -} - //===----------------------------------------------------------------------===// // Utilities //===----------------------------------------------------------------------===// @@ -1721,9 +1704,9 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef operands) { OpFoldResult PrimDtypeOp::fold(ArrayRef operands) { BaseTensorType tensorType = a().getType().cast(); if (tensorType.hasDtype()) { - int64_t dtypeInt = getDtypeIntegerFromMlirType(tensorType.getDtype()); - if (dtypeInt != -1) - return getI64IntegerAttr(getContext(), dtypeInt); + torch_upstream::ScalarType scalarType = + Torch::getScalarTypeForType(tensorType.getDtype()); + return getI64IntegerAttr(getContext(), static_cast(scalarType)); } return nullptr; }