mirror of https://github.com/llvm/torch-mlir
torch: use ScalarType enum instead of raw constants (#1020)
This patch replaces the use of raw integers like 6, 4, etc. (that represent PyTorch's scalar types) with named values from the ScalarType enum (e.g. `ScalarType::Float`, `ScalarType::Long`, etc.) in code for folding `prim.dtype` ops into numeric constants. This patch isn't strictly a non-functional change, since its use of `Torch::getScalarTypeForType()` implies that the input type has to be one among the supported types, otherwise compilation will abort, whereas previously, compilation proceeded without folding the unsupported data type into a numeric constant.pull/1013/head
parent
d38f2cae5b
commit
6491c69539
|
@ -24,23 +24,6 @@ using namespace mlir;
|
|||
using namespace mlir::torch;
|
||||
using namespace mlir::torch::Torch;
|
||||
|
||||
// see https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h#L28
|
||||
static int64_t getDtypeIntegerFromMlirType(Type dtype) {
|
||||
if (dtype.isa<Float32Type>())
|
||||
return 6;
|
||||
|
||||
if (dtype.isa<BFloat16Type>())
|
||||
return 15;
|
||||
|
||||
if (auto integerType = dtype.dyn_cast<IntegerType>()) {
|
||||
if (integerType.isSignedInteger(64))
|
||||
return 4;
|
||||
if (integerType.isSignlessInteger(1))
|
||||
return 11;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utilities
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1721,9 +1704,9 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
OpFoldResult PrimDtypeOp::fold(ArrayRef<Attribute> operands) {
|
||||
BaseTensorType tensorType = a().getType().cast<BaseTensorType>();
|
||||
if (tensorType.hasDtype()) {
|
||||
int64_t dtypeInt = getDtypeIntegerFromMlirType(tensorType.getDtype());
|
||||
if (dtypeInt != -1)
|
||||
return getI64IntegerAttr(getContext(), dtypeInt);
|
||||
torch_upstream::ScalarType scalarType =
|
||||
Torch::getScalarTypeForType(tensorType.getDtype());
|
||||
return getI64IntegerAttr(getContext(), static_cast<int64_t>(scalarType));
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue