torch: use ScalarType enum instead of raw constants (#1020)

This patch replaces the use of raw integers like 6, 4, etc. (that
represent PyTorch's scalar types) with named values from the ScalarType
enum (e.g. `ScalarType::Float`, `ScalarType::Long`, etc.) in code for
folding `prim.dtype` ops into numeric constants.

This patch isn't strictly a non-functional change, since its use of
`Torch::getScalarTypeForType()` implies that the input type has to be
one among the supported types, otherwise compilation will abort, whereas
previously, compilation proceeded without folding the unsupported data
type into a numeric constant.
pull/1013/head
Ashay Rane 2022-07-07 14:21:05 -07:00 committed by GitHub
parent d38f2cae5b
commit 6491c69539
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 20 deletions

View File

@ -24,23 +24,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28
static int64_t getDtypeIntegerFromMlirType(Type dtype) {
if (dtype.isa<Float32Type>())
return 6;
if (dtype.isa<BFloat16Type>())
return 15;
if (auto integerType = dtype.dyn_cast<IntegerType>()) {
if (integerType.isSignedInteger(64))
return 4;
if (integerType.isSignlessInteger(1))
return 11;
}
return -1;
}
//===----------------------------------------------------------------------===//
// Utilities
//===----------------------------------------------------------------------===//
@ -1721,9 +1704,9 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
BaseTensorType tensorType = a().getType().cast<BaseTensorType>();
if (tensorType.hasDtype()) {
int64_t dtypeInt = getDtypeIntegerFromMlirType(tensorType.getDtype());
if (dtypeInt != -1)
return getI64IntegerAttr(getContext(), dtypeInt);
torch_upstream::ScalarType scalarType =
Torch::getScalarTypeForType(tensorType.getDtype());
return getI64IntegerAttr(getContext(), static_cast<int64_t>(scalarType));
}
return nullptr;
}