mirror of https://github.com/llvm/torch-mlir
Avoid `using` the `torch_upstream` namespace.
This is code that we always want to treat as "foreign" and not get too comfortable using in many functions. One way to accomplish that is to make it a bit clunkier to use. Also, fix Utils.cpp to match the LLVM/MLIR coding conventions (don't define functions inside namespaces -- prefer `using` and explicit qualification).pull/669/head
parent
84a9693006
commit
7ea50a537a
|
@ -21,6 +21,10 @@
|
||||||
// original PyTorch license and the code here should not be mixed with "code
|
// original PyTorch license and the code here should not be mixed with "code
|
||||||
// that we [Torch-MLIR] write".
|
// that we [Torch-MLIR] write".
|
||||||
|
|
||||||
|
// Note: As a coding convention, we should never `using` the `torch_upstream`
|
||||||
|
// namespace. This is to ensure that at a glance from the code, it is clear
|
||||||
|
// that we are referencing upstream types.
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace torch_upstream {
|
namespace torch_upstream {
|
||||||
|
|
|
@ -20,7 +20,6 @@
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::torch;
|
using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
using namespace mlir::torch::torch_upstream; // For ScalarType and type
|
|
||||||
|
|
||||||
// Helper funtion to get rank of `Base tensor type`.
|
// Helper funtion to get rank of `Base tensor type`.
|
||||||
// -1 is returned if the tensorRank can't be determined.
|
// -1 is returned if the tensorRank can't be determined.
|
||||||
|
|
|
@ -10,21 +10,19 @@
|
||||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||||
|
|
||||||
using namespace mlir::torch::torch_upstream;
|
using namespace mlir;
|
||||||
|
using namespace mlir::torch;
|
||||||
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
namespace mlir {
|
int64_t Torch::toPositiveDim(int64_t dim, int64_t inputRank) {
|
||||||
namespace torch {
|
|
||||||
namespace Torch {
|
|
||||||
|
|
||||||
int64_t toPositiveDim(int64_t dim, int64_t inputRank) {
|
|
||||||
return dim >= 0 ? dim : dim + inputRank;
|
return dim >= 0 ? dim : dim + inputRank;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isValidDim(int64_t dim, int64_t inputRank) {
|
bool Torch::isValidDim(int64_t dim, int64_t inputRank) {
|
||||||
return dim >= 0 && dim < inputRank;
|
return dim >= 0 && dim < inputRank;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
|
bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
|
||||||
auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
|
auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
|
||||||
if (!listConstruct)
|
if (!listConstruct)
|
||||||
return false;
|
return false;
|
||||||
|
@ -32,30 +30,30 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
ScalarType getScalarTypeForType(Type type) {
|
torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
||||||
if (type.isa<Float32Type>())
|
if (type.isa<Float32Type>())
|
||||||
return ScalarType::Float;
|
return torch_upstream::ScalarType::Float;
|
||||||
if (type.isa<Float64Type>())
|
if (type.isa<Float64Type>())
|
||||||
return ScalarType::Double;
|
return torch_upstream::ScalarType::Double;
|
||||||
if (type.isSignedInteger(64))
|
if (type.isSignedInteger(64))
|
||||||
return ScalarType::Long;
|
return torch_upstream::ScalarType::Long;
|
||||||
if (type.isSignedInteger(32))
|
if (type.isSignedInteger(32))
|
||||||
return ScalarType::Int;
|
return torch_upstream::ScalarType::Int;
|
||||||
if (type.isUnsignedInteger(1))
|
if (type.isUnsignedInteger(1))
|
||||||
return ScalarType::Bool;
|
return torch_upstream::ScalarType::Bool;
|
||||||
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
llvm::report_fatal_error("unhandled type for getScalarTypeForType");
|
||||||
}
|
}
|
||||||
|
|
||||||
Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
static Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||||
Type dtype) {
|
Type dtype) {
|
||||||
int intType = (int)getScalarTypeForType(dtype);
|
int intType = (int)getScalarTypeForType(dtype);
|
||||||
return rewriter.create<ConstantIntOp>(loc,
|
return rewriter.create<ConstantIntOp>(loc,
|
||||||
rewriter.getI64IntegerAttr(intType));
|
rewriter.getI64IntegerAttr(intType));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper to convert a tensor to a specific scalar type.
|
// Helper to convert a tensor to a specific scalar type.
|
||||||
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
|
Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
|
||||||
Type dtype) {
|
Value input, Type dtype) {
|
||||||
BaseTensorType origType = input.getType().cast<BaseTensorType>();
|
BaseTensorType origType = input.getType().cast<BaseTensorType>();
|
||||||
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
|
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
|
||||||
// `convertIntVal` contains the corresponding integer for the dtype which is
|
// `convertIntVal` contains the corresponding integer for the dtype which is
|
||||||
|
@ -67,7 +65,3 @@ Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
|
||||||
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
|
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
|
||||||
return converted;
|
return converted;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace Torch
|
|
||||||
} // namespace torch
|
|
||||||
} // namespace mlir
|
|
||||||
|
|
Loading…
Reference in New Issue