Avoid `using` the `torch_upstream` namespace.

This is code that we always want to treat as "foreign" and not get too
comfortable using in many functions. One way to accomplish that is to
make it a bit clunkier to use.

Also, fix Utils.cpp to match the LLVM/MLIR coding conventions (don't
define functions inside namespaces -- prefer `using` and explicit
qualification).
pull/669/head
Sean Silva 2022-03-16 00:08:45 +00:00
parent 84a9693006
commit 7ea50a537a
3 changed files with 20 additions and 23 deletions

View File

@ -21,6 +21,10 @@
// original PyTorch license and the code here should not be mixed with "code // original PyTorch license and the code here should not be mixed with "code
// that we [Torch-MLIR] write". // that we [Torch-MLIR] write".
// Note: As a coding convention, we should never `using` the `torch_upstream`
// namespace. This is to ensure that at a glance from the code, it is clear
// that we are referencing upstream types.
namespace mlir { namespace mlir {
namespace torch { namespace torch {
namespace torch_upstream { namespace torch_upstream {

View File

@ -20,7 +20,6 @@
using namespace mlir; using namespace mlir;
using namespace mlir::torch; using namespace mlir::torch;
using namespace mlir::torch::Torch; using namespace mlir::torch::Torch;
using namespace mlir::torch::torch_upstream; // For ScalarType and type
// Helper function to get rank of `Base tensor type`. // Helper function to get rank of `Base tensor type`.
// -1 is returned if the tensorRank can't be determined. // -1 is returned if the tensorRank can't be determined.

View File

@ -10,21 +10,19 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
using namespace mlir::torch::torch_upstream; using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
/// Convert a (possibly negative) PyTorch-style dimension index into a
/// non-negative index for a tensor of rank `inputRank`.
/// A negative `dim` counts from the end, so it is shifted by `inputRank`;
/// a non-negative `dim` is returned unchanged. No range checking is done
/// here — callers use isValidDim for that.
int64_t Torch::toPositiveDim(int64_t dim, int64_t inputRank) {
  if (dim < 0)
    return dim + inputRank;
  return dim;
}
bool isValidDim(int64_t dim, int64_t inputRank) { bool Torch::isValidDim(int64_t dim, int64_t inputRank) {
return dim >= 0 && dim < inputRank; return dim >= 0 && dim < inputRank;
} }
bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) { bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
auto listConstruct = v.getDefiningOp<PrimListConstructOp>(); auto listConstruct = v.getDefiningOp<PrimListConstructOp>();
if (!listConstruct) if (!listConstruct)
return false; return false;
@ -32,30 +30,30 @@ bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
return true; return true;
} }
ScalarType getScalarTypeForType(Type type) { torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
if (type.isa<Float32Type>()) if (type.isa<Float32Type>())
return ScalarType::Float; return torch_upstream::ScalarType::Float;
if (type.isa<Float64Type>()) if (type.isa<Float64Type>())
return ScalarType::Double; return torch_upstream::ScalarType::Double;
if (type.isSignedInteger(64)) if (type.isSignedInteger(64))
return ScalarType::Long; return torch_upstream::ScalarType::Long;
if (type.isSignedInteger(32)) if (type.isSignedInteger(32))
return ScalarType::Int; return torch_upstream::ScalarType::Int;
if (type.isUnsignedInteger(1)) if (type.isUnsignedInteger(1))
return ScalarType::Bool; return torch_upstream::ScalarType::Bool;
llvm::report_fatal_error("unhandled type for getScalarTypeForType"); llvm::report_fatal_error("unhandled type for getScalarTypeForType");
} }
/// Materialize a `torch.constant.int` holding the integer code of the
/// upstream PyTorch ScalarType that corresponds to `dtype`.
/// File-local helper (static), so it is not part of the Torch:: API surface.
static Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                                     Type dtype) {
  // The dtype argument of ops like aten.to.dtype is the raw enum value.
  auto scalarTypeCode =
      static_cast<int>(getScalarTypeForType(dtype));
  return rewriter.create<ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(scalarTypeCode));
}
// Helper to convert a tensor to a specific scalar type. // Helper to convert a tensor to a specific scalar type.
Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input, Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
Type dtype) { Value input, Type dtype) {
BaseTensorType origType = input.getType().cast<BaseTensorType>(); BaseTensorType origType = input.getType().cast<BaseTensorType>();
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype); Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
// `convertIntVal` contains the corresponding integer for the dtype which is // `convertIntVal` contains the corresponding integer for the dtype which is
@ -67,7 +65,3 @@ Value convertTensorToDtype(PatternRewriter &rewriter, Location loc, Value input,
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal); loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
return converted; return converted;
} }
} // namespace Torch
} // namespace torch
} // namespace mlir