From 494089d53db4c183b3ba12e36f61ce1c7553984c Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 29 Jan 2024 12:59:33 -0500 Subject: [PATCH] Clang format refresh (#2812) After noticing a number of commits with unrelated formatting changes, I think something was changed with clang-format at one point and we're seeing a number of unrelated changes. Doing a refresh can help avoid this. The changes made here came from ``` find lib -iname *.h -o -iname *.cpp | xargs clang-format -i --style=llvm find include -iname *.h -o -iname *.cpp | xargs clang-format -i --style=llvm find projects -iname *.h -o -iname *.cpp | xargs clang-format -i --style=llvm ``` --- include/torch-mlir-c/Dialects.h | 2 +- .../Dialect/TMTensor/IR/TMTensorInterfaces.h | 2 +- .../Conversion/TorchOnnxToTorch/Patterns.h | 7 +- .../TorchToTosa/TosaLegalizeCommon.h | 20 +- .../TorchToTosa/TosaLegalizeUtils.h | 6 +- .../torch-mlir/Dialect/Torch/IR/TorchTraits.h | 6 +- .../Dialect/Torch/Transforms/Passes.h | 6 +- .../Dialect/Torch/Utils/TorchUpstream.h | 7 +- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 3 +- lib/CAPI/Dialects.cpp | 5 +- lib/Conversion/Passes.cpp | 4 +- .../TorchConversionToMLProgram.cpp | 3 +- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 105 +++-- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 91 ++-- lib/Conversion/TorchToArith/TorchToArith.cpp | 26 +- lib/Conversion/TorchToLinalg/DataMovement.cpp | 125 +++-- .../TorchToLinalg/IndirectDataMovement.cpp | 67 +-- lib/Conversion/TorchToLinalg/Linear.cpp | 70 +-- lib/Conversion/TorchToLinalg/Random.cpp | 1 - lib/Conversion/TorchToLinalg/Reduction.cpp | 22 +- .../TorchToLinalg/TensorConstructors.cpp | 407 ++++++++-------- .../TorchToLinalg/Uncategorized.cpp | 5 +- lib/Conversion/TorchToLinalg/Utils.cpp | 2 +- lib/Conversion/TorchToStablehlo/Basic.cpp | 6 +- .../TorchToStablehlo/GatherScatter.cpp | 4 +- lib/Conversion/TorchToStablehlo/Pooling.cpp | 135 +++--- lib/Conversion/TorchToStablehlo/Reduction.cpp | 2 +- lib/Conversion/TorchToStablehlo/ViewLike.cpp | 5 +- .../TorchToTosa/TosaLegalizeCommon.cpp | 20 +- .../TorchToTosa/TosaLegalizeUtils.cpp | 41 +- lib/Dialect/TMTensor/Transforms/Bufferize.cpp | 3 +- lib/Dialect/Torch/IR/TorchDialect.cpp | 2 +- lib/Dialect/Torch/IR/TorchOps.cpp | 74 +-- lib/Dialect/Torch/IR/TorchTypes.cpp | 8 +- .../Torch/Transforms/DecomposeComplexOps.cpp | 206 ++++---- .../Torch/Transforms/GlobalizeObjectGraph.cpp | 17 +- .../Torch/Transforms/InlineGlobalSlots.cpp | 13 +- .../Transforms/LowerToBackendContract.cpp | 17 +- .../ReifyAbstractInterpCalculationsUtils.cpp | 8 +- .../Transforms/SimplifyDtypeCalculations.cpp | 3 +- lib/Dialect/Torch/Utils/Utils.cpp | 8 +- .../IR/TorchConversionDialect.cpp | 6 +- .../Transforms/BackendTypeConversion.cpp | 95 ++-- .../Transforms/ConvertCustomQuantOp.cpp | 65 +-- .../VerifyLinalgOnTensorsBackendContract.cpp | 4 +- .../VerifyStablehloBackendContract.cpp | 3 +- .../csrc/base_lazy_backend/backend_impl.cpp | 72 +-- .../ltc/csrc/base_lazy_backend/backend_impl.h | 53 ++- .../ltc/csrc/base_lazy_backend/dynamic_ir.cpp | 29 +- .../mlir_lowering_context.cpp | 183 ++++--- .../base_lazy_backend/mlir_lowering_context.h | 64 ++- .../mlir_native_functions.cpp | 445 ++++++++++-------- .../ltc/csrc/base_lazy_backend/mlir_node.cpp | 81 ++-- .../ltc/csrc/base_lazy_backend/mlir_node.h | 48 +- .../base_lazy_backend/mlir_node_lowering.cpp | 178 +++---- .../base_lazy_backend/mlir_node_lowering.h | 6 +- .../base_lazy_backend/ops/device_data.cpp | 24 +- .../csrc/base_lazy_backend/ops/device_data.h | 20 +- .../csrc/base_lazy_backend/ops/generic.cpp | 8 +- .../ltc/csrc/base_lazy_backend/ops/generic.h | 12 +- .../ltc/csrc/base_lazy_backend/ops/index.cpp | 32 +- .../ltc/csrc/base_lazy_backend/ops/index.h | 32 +- .../ltc/csrc/base_lazy_backend/ops/ivalue.cpp | 8 +- .../ltc/csrc/base_lazy_backend/ops/ivalue.h | 10 +- .../ltc/csrc/base_lazy_backend/ops/split.cpp | 32 +- .../ltc/csrc/base_lazy_backend/ops/split.h | 30 +- .../ltc/csrc/base_lazy_backend/ops/to_copy.h | 73 +-- .../csrc/base_lazy_backend/ops/unbind_int.cpp | 12 +- .../csrc/base_lazy_backend/ops/unbind_int.h | 8 +- .../base_lazy_backend/shape_inference.cpp | 253 +++++----- .../ltc/csrc/base_lazy_backend/tensor.cpp | 10 +- projects/ltc/csrc/base_lazy_backend/tensor.h | 3 +- .../csrc/base_lazy_backend/utils/exception.h | 4 +- .../base_lazy_backend/utils/jit_utils.cpp | 16 +- .../csrc/base_lazy_backend/utils/jit_utils.h | 2 +- .../base_lazy_backend/utils/string_utils.h | 60 +-- .../csrc/base_lazy_backend/utils/sys_utils.h | 13 +- .../base_lazy_backend/utils/tensor_utils.cpp | 144 +++--- .../base_lazy_backend/utils/tensor_utils.h | 27 +- .../reference_lazy_backend/backend_impl.cpp | 32 +- .../reference_lazy_backend_pybind.cpp | 26 +- 81 files changed, 1972 insertions(+), 1815 deletions(-) diff --git a/include/torch-mlir-c/Dialects.h b/include/torch-mlir-c/Dialects.h index 99156c170..60f6ec1e5 100644 --- a/include/torch-mlir-c/Dialects.h +++ b/include/torch-mlir-c/Dialects.h @@ -22,4 +22,4 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Torch, torch); } #endif -#endif // TORCHMLIR_C_DIALECTS_H +#endif // TORCHMLIR_C_DIALECTS_H diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h index f16b436c8..159bcea78 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h @@ -10,9 +10,9 @@ #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Support/LLVM.h" diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 44e33ab09..2df6f95c8 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -78,8 +78,8 @@ struct OpBinder { return failure(); return success(); } - - ParseResult tensorOperandsList( llvm::SmallVectorImpl &values) { + + ParseResult tensorOperandsList(llvm::SmallVectorImpl &values) { for (uint32_t i = 0; i < op->getNumOperands(); i++) { values.push_back(op->getOperand(i)); } @@ -97,7 +97,8 @@ struct OpBinder { return success(); } - ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) { + ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, + int64_t idx) { if (idx >= op->getNumResults()) return failure(); auto t = toValidTensorType(op->getResult(idx).getType()); diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h index 16bf235de..44b9bbdde 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h @@ -37,33 +37,31 @@ TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op, return CreateOpAndInfer(rewriter, op->getLoc(), outType, lhs, rhs); } -// This specialization is for Div op. Unlike other binary ops, it doesn't support -// floating type. +// This specialization is for Div op. Unlike other binary ops, it doesn't +// support floating type. template <> tosa::DivOp createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op, TensorType outType, Value lhs, Value rhs); std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, - Operation *op, - Value params_value, - Value index_value, - int32_t axis); + Operation *op, + Value params_value, + Value index_value, + int32_t axis); // Lowers torch.aten.Gather operators to a sequence of TOSA ops. // Revised from // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc -std::optional convertGatherNdOp(PatternRewriter &rewriter, - Operation *op, Type out_type, - Value params_value, - Value indices_value); +std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, + Type out_type, Value params_value, + Value indices_value); std::optional convertScatterNdOp(PatternRewriter &rewriter, Operation *op, Type outType, Value paramsValue, Value indicesValue, Value fillValues); - // Lowers ReduceAll to a sequence of TOSA ops. std::optional convertReduceAllOp(PatternRewriter &rewriter, Operation *op, diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 14cf9cba7..44c033eb8 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -67,7 +67,7 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); // op. This allows shape inference during the framework to TOSA lowering. template TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, - Args &&... args) { + Args &&...args) { auto op = rewriter.create(loc, result_ty, args...); InferShapedTypeOpInterface shapeInterface = @@ -111,7 +111,7 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, template void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op, - Type result_ty, Args &&... args) { + Type result_ty, Args &&...args) { auto result = CreateOpAndInfer(rewriter, op->getLoc(), result_ty, args...); rewriter.replaceOp(op, result->getResults()); @@ -119,7 +119,7 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op, // Get accumulator type for AvgPool2dOp. LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, - TypeAttr &accType); + TypeAttr &accType); } // namespace tosa } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h index 20f1bc109..271481f0a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h @@ -36,8 +36,7 @@ class HasValueSemantics // This is a weaker form of HasValueSemantics, since that trait also requires no // aliasing. That is, HasValueSemantics implies this trait. template -class ReadOnly - : public ::mlir::OpTrait::TraitBase {}; +class ReadOnly : public ::mlir::OpTrait::TraitBase {}; // If a Torch op has this trait, it means that the op is a "trailing underscore" // op variant that performs an in-place operation on its first argument. These @@ -62,7 +61,8 @@ class AllowsTypeRefinement // by the IValue importer. template class AllowedInModuleInitializer - : public ::mlir::OpTrait::TraitBase {}; + : public ::mlir::OpTrait::TraitBase {}; } // namespace OpTrait } // namespace Torch diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index fd7468847..71111c00c 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -61,7 +61,8 @@ struct TorchLoweringPipelineOptions Option extraLibrary{ *this, "extra-library", - llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")}; + llvm::cl::desc("Filename of MLIR module for splicing into the abstract " + "interpretation library.")}; }; /// Creates a pipeline that lowers the object graph IR that is produced by @@ -125,8 +126,7 @@ createSimplifyDtypeCalculationsPass(); std::unique_ptr> createDropAbstractInterpCalculationsPass(); -std::unique_ptr> -createEraseModuleInitializerPass(); +std::unique_ptr> createEraseModuleInitializerPass(); std::unique_ptr> createLowerToBackendContractPass(int maxIterations, bool decompose, diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index efb114fbf..043dd9254 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -140,12 +140,7 @@ enum Reduction { None, Mean, Sum, END }; // Source: // https://github.com/pytorch/pytorch/blob/master/c10/core/MemoryFormat.h //===----------------------------------------------------------------------===// -enum MemoryFormat { - Contiguous, - Preserve, - ChannelsLast, - ChannelsLast3d -}; +enum MemoryFormat { Contiguous, Preserve, ChannelsLast, ChannelsLast3d }; //===----------------------------------------------------------------------===// // Possible values for `layout` argument in PyTorch ops that support it. diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index b5c815ca7..beafe7d21 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -121,8 +121,7 @@ LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter, // Helper to create a tensor filled with the given scalar. Scalar would be // converted the to the element type of the given tensor type. Value createInitTensor(PatternRewriter &rewriter, Location loc, - BaseTensorType resultType, Value scalar, - Value sizeList); + BaseTensorType resultType, Value scalar, Value sizeList); // Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` // would be converted to the element type of the given `inputType`. diff --git a/lib/CAPI/Dialects.cpp b/lib/CAPI/Dialects.cpp index 06be821c0..048e37e08 100644 --- a/lib/CAPI/Dialects.cpp +++ b/lib/CAPI/Dialects.cpp @@ -9,7 +9,8 @@ #include "torch-mlir-c/Dialects.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "mlir/CAPI/Registration.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch, mlir::torch::Torch::TorchDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch, + mlir::torch::Torch::TorchDialect) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index b9af2afa3..6d8adbaa1 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -30,6 +30,4 @@ namespace { #include "torch-mlir/Conversion/Passes.h.inc" } // end namespace -void mlir::torch::registerConversionPasses() { - ::registerPasses(); -} +void mlir::torch::registerConversionPasses() { ::registerPasses(); } diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index eab81c2be..6a00e5190 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -82,7 +82,8 @@ public: // temp = multiplier * currentSeed + incrementStep Value mul = rewriter.create(loc, currentSeed, multiplier); Value seed = rewriter.create(loc, mul, incrementStep); - globalVar = rewriter.create(loc, seed, globalVar, ValueRange()); + globalVar = + rewriter.create(loc, seed, globalVar, ValueRange()); rewriter.create( loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), globalVar); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 4ee71af3f..df20a8351 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -29,7 +29,8 @@ using namespace mlir::torch::onnx_c; // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainGtoP( OnnxCustomOpConversionPattern &patterns) { - patterns.onOp("HardSigmoid", 6, + patterns.onOp( + "HardSigmoid", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value tensorOperand; @@ -39,8 +40,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.f32FloatAttr(beta, "beta", 0.5f) || binder.tensorResultType(resultType)) return failure(); - - // HardSigmoid computes the following expression: max(0, min(1, alpha * x + beta)) + + // HardSigmoid computes the following expression: + // max(0, min(1, alpha * x + beta)) Value constAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); @@ -51,7 +53,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // Expression: alpha * x + beta Value alpha_x_plus_beta = rewriter.create( - binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha); + binder.getLoc(), resultType, tensorOperand, constBeta, + /*alpha=*/constAlpha); // Expression: min(1, alpha * x + beta) Value constantOne = rewriter.create( @@ -100,7 +103,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); - }); + }); patterns.onOp("LessOrEqual", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -109,9 +112,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) { return failure(); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); - return success(); + return success(); }); patterns.onOp("Log", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -126,7 +129,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); patterns.onOp("MatMul", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { + [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || @@ -206,20 +209,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); patterns.onOp("Mul", 7, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - return success(); - }); + binder.op, resultType, lhs, rhs); + return success(); + }); patterns.onOp("NonZero", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { @@ -332,41 +335,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); - patterns.onOp("Max", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - llvm::SmallVector operands; - if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || - operands.size() == 0) { - return failure(); - } - Value result = operands[0]; - for (uint64_t i = 1; i < operands.size(); i++) { - result = rewriter.create( - binder.getLoc(), resultType, result, operands[i]); - } - rewriter.replaceOp(binder.op, result.getDefiningOp()); - return success(); - }); - patterns.onOp("Min", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - llvm::SmallVector operands; - if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || - operands.size() == 0) { - return failure(); - } - Value result = operands[0]; - for (uint64_t i = 1; i < operands.size(); i++) { - result = rewriter.create( - binder.getLoc(), resultType, result, operands[i]); - } - rewriter.replaceOp( - binder.op, result.getDefiningOp()); - return success(); - }); + patterns.onOp( + "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (uint64_t i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp(binder.op, result.getDefiningOp()); + return success(); + }); + patterns.onOp( + "Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (uint64_t i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp(binder.op, result.getDefiningOp()); + return success(); + }); patterns.onOp("Neg", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -693,7 +693,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); - Value cstFalse = rewriter.create(binder.getLoc(), false); + Value cstFalse = + rewriter.create(binder.getLoc(), false); Value cstCeilMode = cstFalse; Value cstCountIncludePad = cstFalse; Value cstNone = rewriter.create(binder.getLoc()); @@ -903,7 +904,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return failure(); } rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); + binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp( diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 6569e3abc..87f68375a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -42,56 +42,63 @@ Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, void mlir::torch::onnx_c::populateDefaultDomainQtoZ( OnnxCustomOpConversionPattern &patterns) { - patterns.onOp("QuantizeLinear", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - llvm::SmallVector operands; - if (binder.tensorOperands(operands, 3) || - binder.tensorResultType(resultType)) - return failure(); + patterns.onOp( + "QuantizeLinear", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType)) + return failure(); - Value operand = operands[0]; - Value scale = operands[1]; - Value zeropoint = operands[2]; + Value operand = operands[0]; + Value scale = operands[1]; + Value zeropoint = operands[2]; - auto scaleTy = scale.getType().dyn_cast(); - if (!scaleTy || !scaleTy.hasSizes()) - return rewriter.notifyMatchFailure(binder.op, - "requires known rank"); - if (!resultType.hasDtype()) - return rewriter.notifyMatchFailure( - binder.op, "requires known result dtype"); + auto scaleTy = scale.getType().dyn_cast(); + if (!scaleTy || !scaleTy.hasSizes()) + return rewriter.notifyMatchFailure(binder.op, "requires known rank"); + if (!resultType.hasDtype()) + return rewriter.notifyMatchFailure(binder.op, + "requires known result dtype"); - if (scaleTy.getSizes().size() == 0) { - Type qTy = resultType.getDtype(); + if (scaleTy.getSizes().size() == 0) { + Type qTy = resultType.getDtype(); - if (qTy.isUnsignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(32)) { - qTy = rewriter.getType(); - } else { - return rewriter.notifyMatchFailure(binder.op, "unsupported result dtype"); - } + if (qTy.isUnsignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(32)) { + qTy = rewriter.getType(); + } else { + return rewriter.notifyMatchFailure(binder.op, + "unsupported result dtype"); + } - auto qTensorTy = rewriter.getType(resultType.getOptionalSizes(), qTy); - auto torchqTy = Torch::getScalarTypeForType(qTy); + auto qTensorTy = rewriter.getType( + resultType.getOptionalSizes(), qTy); + auto torchqTy = Torch::getScalarTypeForType(qTy); - Value tyConst = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(torchqTy))); + Value tyConst = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + static_cast(torchqTy))); - scale = rewriter.create(binder.getLoc(), rewriter.getType(), scale); - zeropoint = rewriter.create(binder.getLoc(), rewriter.getType(), zeropoint); + scale = rewriter.create( + binder.getLoc(), rewriter.getType(), scale); + zeropoint = rewriter.create( + binder.getLoc(), rewriter.getType(), zeropoint); - auto quantize = rewriter.create(binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst); - rewriter.replaceOpWithNewOp(binder.op, resultType, quantize); - return success(); - } + auto quantize = rewriter.create( + binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst); + rewriter.replaceOpWithNewOp( + binder.op, resultType, quantize); + return success(); + } - return failure(); - }); + return failure(); + }); patterns.onOp( "QLinearMatMul", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1245,7 +1252,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } // Convert dynamic shape dimension. - for (unsigned i = 0; i < shape.size(); i++){ + for (unsigned i = 0; i < shape.size(); i++) { if (shape[i] == ShapedType::kDynamic) shape[i] = Torch::kUnknownSize; } diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index e1e53acb2..d2000d7fc 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -43,7 +43,8 @@ public: LogicalResult matchAndRewrite(AtenDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto rank = rewriter.create(op->getLoc(), adaptor.getSelf()); + auto rank = + rewriter.create(op->getLoc(), adaptor.getSelf()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), rank); return success(); @@ -74,7 +75,8 @@ public: matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.template replaceOpWithNewOp(op, adaptor.getA(), adaptor.getB()); + rewriter.template replaceOpWithNewOp(op, adaptor.getA(), + adaptor.getB()); return success(); } }; @@ -112,10 +114,10 @@ public: typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value a = - convertScalarToDtype(rewriter, loc, adaptor.getA(), rewriter.getF64Type()); - Value b = - convertScalarToDtype(rewriter, loc, adaptor.getB(), rewriter.getF64Type()); + Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(), + rewriter.getF64Type()); + Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(), + rewriter.getF64Type()); rewriter.replaceOpWithNewOp(op, a, b); return success(); } @@ -176,15 +178,16 @@ public: unsigned bitWidth = elemTy.getIntOrFloatBitWidth(); Type builtinTensorElemTy = IntegerType::get(context, bitWidth); auto shapedType = - RankedTensorType::get(type.getShape(), builtinTensorElemTy); + RankedTensorType::get(type.getShape(), builtinTensorElemTy); auto rawData = elements.getRawData(); - DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( - shapedType, rawData); + DenseElementsAttr newAttr = + DenseElementsAttr::getFromRawBuffer(shapedType, rawData); rewriter.replaceOpWithNewOp(op, newAttr); return success(); } } - if (auto elements = op.getValueAttr().dyn_cast()) { + if (auto elements = + op.getValueAttr().dyn_cast()) { if (auto type = elements.getType().dyn_cast()) { if (auto intType = type.getElementType().dyn_cast()) { Type builtinTensorElemTy = @@ -360,7 +363,8 @@ public: // ----------------------------------------------------------------------------- namespace { -class ConvertTorchToArith : public ConvertTorchToArithBase { +class ConvertTorchToArith + : public ConvertTorchToArithBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index e96d65970..add32ff05 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -110,22 +110,32 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, // Example: // input = tensor([[[0., 1., 2., 3.], // [4., 5., 6., 7.]]]) -// torch.ops.aten.reflection_pad1d(input, (3,1)) ; padding_left = 3, padding_right = 1 -// tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], -// [7., 6., 5., 4., 5., 6., 7., 6.]]]) -// Checks: 1) Each of padding_left and padding_right must be non-negative less than size of last dimension -// Implementation: a) Construct a result tensor of shape of input tensor except for the last dimension. -// The last dimension of the result tensor should be last dimension of input tensor + -// left padding size + right padding size. INitialize result tensor to all zeros -// b) Setup affine map to take slice from input tensor of size left padding starting from -// second column onwards as first column is reflection boundary +// torch.ops.aten.reflection_pad1d(input, (3,1)); +// padding_left = 3, +// padding_right = 1 +// output = tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], +// [7., 6., 5., 4., 5., 6., 7., 6.]]]) +// Checks: 1) Each of padding_left and padding_right must be non-negative and +// less than the size of the last dimension. +// Implementation: a) Construct a result tensor of +// shape of input tensor except for the last dimension. +// The last dimension of the result tensor should be last +// dimension of input tensor + left padding size + right +// padding size. Initialize result tensor to all zeros +// b) Setup affine map to take slice from input tensor of size +// left padding starting from +// second column onwards as first column is reflection +// boundary // c) Reflect the affine map to have resultant slice reflected // d) Take the slice and write from begining in result tensor // e) write the original tensor next into result tensor -// f) Setup affine map to take slice from input tensor of right padding size ending -// at second last column as last column is reflection boundary for right padding +// f) Setup affine map to take slice from input tensor of right +// padding size ending +// at second last column as last column is reflection +// boundary for right padding // g) Reflect the affine map to have resultant slice reflected -// h) Take the slice and write from left padding size + orignal tensor last dim size +// h) Take the slice and write from left padding size + orignal +// tensor last dim size // into result tensor // Uses the ideas/code used for AtenReflectionPad2dOp namespace { @@ -138,7 +148,7 @@ public: ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - + SmallVector padInts; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) return rewriter.notifyMatchFailure( @@ -158,55 +168,68 @@ public: return rewriter.create(loc, x, y); }; - enum PadLocation {PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER=2}; + enum PadLocation { PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER = 2 }; Value input = adaptor.getSelf(); Type indexType = rewriter.getIndexType(); Value zero = getConstant(rewriter, loc, 0, indexType); Value one = getConstant(rewriter, loc, 1, indexType); auto inputType = llvm::cast(input.getType()); - auto outputType = llvm::cast(getTypeConverter()->convertType(op->getResult(0).getType())); + auto outputType = llvm::cast( + getTypeConverter()->convertType(op->getResult(0).getType())); unsigned numDims = inputType.getRank(); assert(numDims >= 2 && "Not enough input dimensions"); int64_t lastDim = numDims - 1; SmallVector inputShape = getTensorSizes(rewriter, loc, input); - Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, inputShape[2] will give 4 + Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, + // inputShape[2] will give 4 Value tileWidth[3], extractOffset[3], insertOffset[3]; - - tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType); - tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType); + + tileWidth[PAD_LEFT] = + getConstant(rewriter, loc, padInts[PAD_LEFT], indexType); + tileWidth[PAD_RIGHT] = + getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType); tileWidth[PAD_CENTER] = lastDimSize; extractOffset[PAD_LEFT] = one; - // for (1,2,4) input, padding (3,1) lastDimSize=4, 4 - 1 - 1 = 2 [3,5, 6,7], so start offset to 6, which is right - // lasDimSize - (tileWidth[PAD_RIGHT] + one) - extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one)); + // The offset for the right hand padding "bar" is: + // [right] lastDimSize - (tileWidth[PAD_RIGHT] + one) + extractOffset[PAD_RIGHT] = + createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one)); extractOffset[PAD_CENTER] = zero; insertOffset[PAD_LEFT] = zero; insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]); insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT]; - SmallVector resultShape{inputShape}; - // Result's last dimension will have shape lastDimSize + left padding size + right padding size - resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT])); - Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType()); + // Result's last dimension will have size: + // lastDimSize + left padding size + right padding size + resultShape[lastDim] = + createIAdd(resultShape[lastDim], + createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT])); + Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, + inputType.getElementType()); - // Helper to reflect/reverse the i-th dimension of an affine map without symbols. This only works if applied on a tensor - // for which the corresponding dimension has a statically known size - auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) { + // Helper to reflect/reverse the i-th dimension of an affine map without + // symbols. This only works if applied on a tensor for which the + // corresponding dimension has a statically known size + auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, + int64_t size) { AffineExpr d = map.getResult(i); - return map.replace(d, size - d - 1, numDims, 0); // left reflect for (3,1) on input shape (1,2,4). size = 3, lastDim=2, numDims=3 + return map.replace(d, size - d - 1, numDims, + 0); // left reflect for (3,1) on input shape (1,2,4). + // size = 3, lastDim=2, numDims=3 }; - SmallVector iteratorTypes{numDims, utils::IteratorType::parallel}; + SmallVector iteratorTypes{ + numDims, utils::IteratorType::parallel}; auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); SmallVector allOneStrides(numDims, one); auto addTileToResult = [&](PadLocation padPosition) { - // Create the tile by extracting a slice from the input tensor. + // Create the tile by extracting a slice from the input tensor. SmallVector extractShape{inputShape}; extractShape[lastDim] = tileWidth[padPosition]; SmallVector extractOffsets(numDims, zero); @@ -214,35 +237,39 @@ public: Value tile = rewriter.create( loc, input, extractOffsets, extractShape, allOneStrides); - auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); - // Setup the affine map function to resverse the tile along the horizontal for left and right slices - if(padPosition < PAD_CENTER) { - inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); - // Take reflected slice as per inputMap - tile = rewriter.create(loc, llvm::cast(tile.getType()), tile, - tile, ArrayRef({inputMap, idMap}), iteratorTypes, - [](OpBuilder &b, Location nestedLoc, ValueRange args) { - b.create(nestedLoc, args[0]); - }).getResult(0); + // Setup the affine map function to resverse the tile along the horizontal + // for left and right slices + if (padPosition < PAD_CENTER) { + inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); + // Take reflected slice as per inputMap + tile = rewriter + .create( + loc, llvm::cast(tile.getType()), tile, + tile, ArrayRef({inputMap, idMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + b.create(nestedLoc, args[0]); + }) + .getResult(0); } // Insert the tile in the resultTensor SmallVector insertOffsets(numDims, zero); insertOffsets[lastDim] = insertOffset[padPosition]; - resultTensor = rewriter.create(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + resultTensor = rewriter.create( + loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); }; - - if(padInts[PAD_LEFT] > 0) - addTileToResult(PAD_LEFT); - if(padInts[PAD_RIGHT] > 0) - addTileToResult(PAD_RIGHT); + + if (padInts[PAD_LEFT] > 0) + addTileToResult(PAD_LEFT); + if (padInts[PAD_RIGHT] > 0) + addTileToResult(PAD_RIGHT); addTileToResult(PAD_CENTER); rewriter.replaceOpWithNewOp(op, outputType, resultTensor); return success(); } }; -} +} // namespace namespace { diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index f9ee56070..bfbe45afe 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -79,7 +79,8 @@ public: int64_t dim; if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) return op.emitError("unimplemented: dim is not constant"); - int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); + int64_t inputRank = + adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); @@ -248,9 +249,9 @@ public: } if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) { - return rewriter.notifyMatchFailure( - op, - "Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag."); + return rewriter.notifyMatchFailure(op, + "Unimplemented: Mean and Max mode are " + "not supported yet for EmbeddingBag."); } bool isSparse; @@ -291,28 +292,28 @@ public: SmallVector indicesExpr; indicesExpr.push_back(mlir::getAffineDimExpr(1, context)); auto indicesIndexingMap = - AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, - indicesExpr, context); + AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, + indicesExpr, context); SmallVector offsetsExpr; offsetsExpr.push_back(mlir::getAffineDimExpr(0, context)); auto offsetIndexingMap = - AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, - offsetsExpr, context); + AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, + offsetsExpr, context); SmallVector outputExpr; outputExpr.push_back(mlir::getAffineDimExpr(0, context)); outputExpr.push_back(mlir::getAffineDimExpr(2, context)); auto outputIndexingMap = - AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, - outputExpr, context); + AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, + outputExpr, context); SmallVector indexingMaps = { - indicesIndexingMap, - offsetIndexingMap, - outputIndexingMap, + indicesIndexingMap, + offsetIndexingMap, + outputIndexingMap, }; // Reduce along the indices dim @@ -326,15 +327,15 @@ public: Value indicesLength; if (!discardLastOffset) { SmallVector sizes{getDimOp(rewriter, loc, offsets, 0), - embeddingDim}; + embeddingDim}; initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy); offsetsLength = getDimOp(rewriter, loc, offsets, 0); indicesLength = getDimOp(rewriter, loc, indices, 0); } else { return rewriter.notifyMatchFailure( - op, "Unimplemented: include last offset is not yet " - "supported for EmbeddingBag."); + op, "Unimplemented: include last offset is not yet " + "supported for EmbeddingBag."); } Value embeddingBagResult = @@ -351,10 +352,10 @@ public: Value indexI = b.create(loc, /*value=*/0); Value indexIToInt = castIndexToInt64(b, loc, indexI); - Value one = getConstant( - b, loc, 1, - mlir::IntegerType::get(getContext(), 64, - IntegerType::Signless)); + Value one = + getConstant(b, loc, 1, + mlir::IntegerType::get( + getContext(), 64, IntegerType::Signless)); Value offsetIndexPlusOneInt = b.create(loc, indexIToInt, one); @@ -378,7 +379,7 @@ public: loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex); Value offsetLessThanOrEqualToIndicesIndex = b.create(loc, offsetLessThanIndicesIndex, - offsetEqualToIndicesIndex); + offsetEqualToIndicesIndex); Value indicesIndexLessThanNextOffset = b.create(loc, arith::CmpIPredicate::slt, @@ -393,19 +394,18 @@ public: castIntToIndex(b, loc, indexInIndices)); indexIntoWeight.push_back( b.create(loc, /*value=*/2)); - Value weightElem = b.create( - loc, weight, indexIntoWeight); + Value weightElem = + b.create(loc, weight, indexIntoWeight); - Value addResult = b.create(loc, weightElem, - initTensorElem); - Value select = - b.create(loc, indicesIndexWithinBounds, - addResult, initTensorElem); + Value addResult = + b.create(loc, weightElem, initTensorElem); + Value select = b.create( + loc, indicesIndexWithinBounds, addResult, initTensorElem); b.create(loc, select); - }) - .getResult(0); + }) + .getResult(0); - // cast outputType. + // cast outputType. auto restulType0 = typeConverter->convertType(op->getResult(0).getType()); Value castedEmbeddingBagResult = rewriter.create(loc, restulType0, embeddingBagResult); @@ -439,7 +439,7 @@ public: rewriter.create(loc, resultType3, indicesOut); rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult, - castedBagSizeResult, castedMaxIndices}); + castedBagSizeResult, castedMaxIndices}); return success(); } @@ -552,7 +552,8 @@ static Value makeIndexValuePositive(OpBuilder &b, Location loc, Value index, // e.g. x: [2, 3] // x[[4], [6, 1]] -> x[6, 4] namespace { -class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern { +class ConvertAtenIndexTensorHackedTwinOp + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 6d0d72075..c0585df0b 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -165,7 +165,8 @@ public: Location loc = op->getLoc(); MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); - auto selfRank = adaptor.getSelf().getType().cast().getRank(); + auto selfRank = + adaptor.getSelf().getType().cast().getRank(); Type elementType = adaptor.getSelf().getType().cast().getElementType(); Value c1 = @@ -535,7 +536,8 @@ public: RankedTensorType lhsType = lhs.getType().cast(); RankedTensorType rhsType = rhs.getType().cast(); Type newResultType = getTypeConverter()->convertType(op.getType()); - Type resultElementType = newResultType.cast().getElementType(); + Type resultElementType = + newResultType.cast().getElementType(); Type lhsElementType = lhsType.cast().getElementType(); Type rhsElementType = rhsType.cast().getElementType(); @@ -547,13 +549,15 @@ public: // Convert the inputs element type equivalent to the result' element type. if (lhsElementType != rhsElementType) { if (lhsElementType != resultElementType) { - // True if the lhs element type is not equal to the result' element type. - lhs = torch_to_linalg::convertTensorToElementType( - rewriter, loc, lhs, resultElementType); + // True if the lhs element type is not equal to the result' element + // type. + lhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, lhs, + resultElementType); } else { - // True if the rhs element type is not equal to the result' element type. - rhs = torch_to_linalg::convertTensorToElementType( - rewriter, loc, rhs, resultElementType); + // True if the rhs element type is not equal to the result' element + // type. + rhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, rhs, + resultElementType); } } @@ -571,7 +575,8 @@ public: checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1); Value initTensor0 = createZeroInitTensor( - rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType); + rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, + resultElementType); Value bmm = rewriter @@ -634,7 +639,8 @@ public: return rewriter.notifyMatchFailure(op, "only support constant int strides"); SmallVector dilationInts; - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); @@ -838,8 +844,10 @@ public: Value conv; // the code so far is able to respect all numSpacialDims - // the code below this point is numSpacialDims specific and groupSize specific - // TODO: factor out the above code into a helper function, and then separate convolution into: + // the code below this point is numSpacialDims specific and groupSize + // specific + // TODO: factor out the above code into a helper function, and then separate + // convolution into: // - grouped 1d-3d // - ungrouped 1d-3d if (groupSize == 1) { @@ -854,20 +862,20 @@ public: .getResult(0); break; case 2: - conv = - rewriter - .create( - loc, outputTensor.getType(), ValueRange{paddedInput, weight}, - outputTensor, stridesAttr, dilationAttr) - .getResult(0); + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); break; case 3: - conv = - rewriter - .create( - loc, outputTensor.getType(), ValueRange{paddedInput, weight}, - outputTensor, stridesAttr, dilationAttr) - .getResult(0); + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); break; default: return rewriter.notifyMatchFailure( @@ -877,7 +885,7 @@ public: rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } else { - if(numSpacialDims != 2) + if (numSpacialDims != 2) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D grouped convolution supported"); @@ -901,11 +909,11 @@ public: loc, collapsedType, weight, collapsedDims); conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) - .getResult(0); + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); @@ -979,7 +987,7 @@ public: conv = rewriter.create( loc, outputTensor.getType(), conv, expandOutputTensor.getReassociationIndices()); - Type newResultType = getTypeConverter()->convertType(op.getType()); + Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 26a2c0ea5..35c349a6a 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -194,7 +194,6 @@ public: }; } // namespace - void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 289851cd3..da5ee799a 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -100,11 +100,11 @@ public: if (integerTy.isUnsigned()) return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires input element type " - "to be signed in case of integer"); + "to be signed in case of integer"); } else { return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires Float or Integer " - "input element type"); + "input element type"); } } @@ -144,8 +144,7 @@ public: } Value filledTensorVal = - rewriter.create(loc, fillValue, initTensorVal) - .result(); + rewriter.create(loc, fillValue, initTensorVal).result(); // Create the affine expressions that will be used to // iterate over the input and output tensors. @@ -186,7 +185,7 @@ public: Value resultVal, predicate; if (inElementType.isa()) { - arith::CmpFPredicate predType; + arith::CmpFPredicate predType; if (isMax) { predType = arith::CmpFPredicate::OGT; resultVal = rewriter.create( @@ -198,7 +197,7 @@ public: } predicate = rewriter.create(nestedLoc, predType, - newValue, oldValue); + newValue, oldValue); } else { arith::CmpIPredicate predType; if (isMax) { @@ -220,8 +219,8 @@ public: }); // This cast is required to fix the shape in the case of keepDim=True - Value valuesCast = rewriter.create( - loc, valResultType, linalgOp.getResult(0)); + Value valuesCast = rewriter.create(loc, valResultType, + linalgOp.getResult(0)); Value idxCast = rewriter.create(loc, idxResultType, linalgOp.getResult(1)); rewriter.replaceOp(op, {valuesCast, idxCast}); @@ -345,7 +344,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, Value self = convertScalarToDtype(b, loc, elem, resultElementType); auto abs = b.create(loc, self); AtenLinalgVectorNormOp::Adaptor adaptor(operands); - Value ord = convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType); + Value ord = + convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); } else if (isa(op)) { @@ -427,8 +427,8 @@ private: opInfo.tensorOperand = operands[0]; auto inputType = opInfo.tensorOperand.getType().cast(); - // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the dimensions of the - // input tensor. + // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the + // dimensions of the input tensor. for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 6afae47c1..2b8eac494 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -83,209 +83,224 @@ public: namespace { - // Lower aten.replication_pad2d operator into a sequence of - // tensor.extract_slice and tensor.concat operations. +// Lower aten.replication_pad2d operator into a sequence of +// tensor.extract_slice and tensor.concat operations. - class ConvertAtenReplicationPad2dOp - : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenReplicationPad2dOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); +class ConvertAtenReplicationPad2dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenReplicationPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); - Location loc = op->getLoc(); - Value input = adaptor.getSelf(); - auto inputType = llvm::cast(input.getType()); - int64_t inputRank = inputType.getRank(); - unsigned numDims = inputType.getRank(); - assert(numDims >= 2 && "Not enough input dimensions"); + Location loc = op->getLoc(); + Value input = adaptor.getSelf(); + auto inputType = llvm::cast(input.getType()); + int64_t inputRank = inputType.getRank(); + unsigned numDims = inputType.getRank(); + assert(numDims >= 2 && "Not enough input dimensions"); - SmallVector padInts; - if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) - return rewriter.notifyMatchFailure( + SmallVector padInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure( op, "only support constant int pad ranges"); - uint64_t padRank = padInts.size() / 2; - if (padRank * 2 != padInts.size()) - return rewriter.notifyMatchFailure(op, "pad range size is not even"); - if (inputRank < 0 || padRank > (uint64_t)inputRank) - return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); + uint64_t padRank = padInts.size() / 2; + if (padRank * 2 != padInts.size()) + return rewriter.notifyMatchFailure(op, "pad range size is not even"); + if (inputRank < 0 || padRank > (uint64_t)inputRank) + return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); - SmallVector inputShape = getTensorSizes(rewriter, loc, input); - int64_t hDim = numDims - 1; - int64_t vDim = numDims - 2; - Value hDimSize = inputShape[hDim]; - Value vDimSize = inputShape[vDim]; + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + int64_t hDim = numDims - 1; + int64_t vDim = numDims - 2; + Value hDimSize = inputShape[hDim]; + Value vDimSize = inputShape[vDim]; - enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 }; - enum tileVLoc { TOP = 0, VCENTER = 2, BOTTOM = 1, }; - // vTile denotes the vertical size of the tile - // hTile denotes the horizontal size of the tile - // The padding results are composed of following tiles: - // vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT] - // vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], vTile[VCENTER]hTile[RIGHT] - // vTile[BOTTOM]hTile[LEFT], vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT] - // vTile[VCENTER]hTile[HCENTER] is the original input tensor - Type indexType = rewriter.getIndexType(); - Value vTile[3]; - Value hTile[3]; - vTile[VCENTER] = vDimSize; - hTile[HCENTER] = hDimSize; - vTile[TOP] = getConstant(rewriter, loc, padInts[2], indexType); - vTile[BOTTOM] = getConstant(rewriter, loc, padInts[3], indexType); - hTile[LEFT] = getConstant(rewriter, loc, padInts[0], indexType); - hTile[RIGHT] = getConstant(rewriter, loc, padInts[1], indexType); + enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 }; + enum tileVLoc { + TOP = 0, + VCENTER = 2, + BOTTOM = 1, + }; + // vTile denotes the vertical size of the tile + // hTile denotes the horizontal size of the tile + // The padding results are composed of following tiles: + // vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT] + // vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], + // vTile[VCENTER]hTile[RIGHT] vTile[BOTTOM]hTile[LEFT], + // vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT] + // vTile[VCENTER]hTile[HCENTER] is the original input tensor + Type indexType = rewriter.getIndexType(); + Value vTile[3]; + Value hTile[3]; + vTile[VCENTER] = vDimSize; + hTile[HCENTER] = hDimSize; + vTile[TOP] = getConstant(rewriter, loc, padInts[2], indexType); + vTile[BOTTOM] = getConstant(rewriter, loc, padInts[3], indexType); + hTile[LEFT] = getConstant(rewriter, loc, padInts[0], indexType); + hTile[RIGHT] = getConstant(rewriter, loc, padInts[1], indexType); - bool hasLeftPadding = false; - bool hasRightPadding = false; - bool hasTopPadding = false; - bool hasBottomPadding = false; + bool hasLeftPadding = false; + bool hasRightPadding = false; + bool hasTopPadding = false; + bool hasBottomPadding = false; - for (auto i: {TOP, VCENTER, BOTTOM}){ - for (auto j: {LEFT, HCENTER, RIGHT}) { - auto constVtile{ - mlir::dyn_cast(vTile[i].getDefiningOp()) - .getValue() - .dyn_cast_or_null()}; + for (auto i : {TOP, VCENTER, BOTTOM}) { + for (auto j : {LEFT, HCENTER, RIGHT}) { + auto constVtile{ + mlir::dyn_cast(vTile[i].getDefiningOp()) + .getValue() + .dyn_cast_or_null()}; - auto constHtile{ - mlir::dyn_cast(hTile[j].getDefiningOp()) - .getValue() - .dyn_cast_or_null()}; - auto vSize = constVtile.getInt(); - auto hSize = constHtile.getInt(); + auto constHtile{ + mlir::dyn_cast(hTile[j].getDefiningOp()) + .getValue() + .dyn_cast_or_null()}; + auto vSize = constVtile.getInt(); + auto hSize = constHtile.getInt(); - if ((i == TOP) && (vSize > 0)) - hasTopPadding = true; - if ((i == BOTTOM) && (vSize > 0)) - hasBottomPadding = true; - if ((j == LEFT) && (hSize > 0)) - hasLeftPadding = true; - if ((j == RIGHT) && (hSize > 0)) - hasRightPadding = true; - } + if ((i == TOP) && (vSize > 0)) + hasTopPadding = true; + if ((i == BOTTOM) && (vSize > 0)) + hasBottomPadding = true; + if ((j == LEFT) && (hSize > 0)) + hasLeftPadding = true; + if ((j == RIGHT) && (hSize > 0)) + hasRightPadding = true; } - - auto createSub = [&](Value x, Value y) { - return rewriter.create(loc, x, y); - }; - - // Extract left and right pad tiles. - Value zero = getConstant(rewriter, loc, 0, indexType); - Value one = getConstant(rewriter, loc, 1, indexType); - Value hDimSizeMinusOne = createSub(hDimSize, one); - Value vDimSizeMinusOne = createSub(vDimSize, one); - SmallVector allOneStrides(numDims, one); - - SmallVector extractOffsetsLT(numDims, zero); - extractOffsetsLT[hDim] = zero; - extractOffsetsLT[vDim] = zero; - SmallVector extractShapeLR(numDims, one); - extractShapeLR[hDim] = one; - extractShapeLR[vDim] = vDimSize; - - SmallVector extractOffsetsRight(numDims, zero); - extractOffsetsRight[hDim] = hDimSizeMinusOne; - extractOffsetsRight[vDim] = zero; - - SmallVector extractOffsetsBottom(numDims, zero); - extractOffsetsBottom[hDim] = zero; - extractOffsetsBottom[vDim] = vDimSizeMinusOne; - - SmallVector extractShapeTB(numDims, one); - extractShapeTB[hDim] = hDimSize; - extractShapeTB[vDim] = one; - - SmallVector tensorsLeft; - SmallVector tensorsRight; - SmallVector tensorsCenter; - Value centerTile; - SmallVector tensorsRes; - - if (hasLeftPadding) { - Value vCenterLeftSlice = rewriter.create( - loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); - Value vLeftSlice = vCenterLeftSlice; - if (hasTopPadding) { - Value topLeftValue = rewriter.create( - loc, input, ValueRange{zero, zero, zero, zero}); - //pad vCenterLeftSlice on the top - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - lowPadding[2] = padInts[2]; - vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); - } - if (hasBottomPadding) { - Value bottomLeftValue = rewriter.create (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero}); - - //pad vLeftSlice at the bottom - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - highPadding[2] = padInts[3]; - vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); - } - for (auto i=0; i(loc, 3, tensorsLeft); - tensorsRes.push_back(leftPadTile); - } - if (hasTopPadding) { - Value topHcenterSlice = rewriter.create( - loc, input, extractOffsetsLT, extractShapeTB, allOneStrides); - for (auto i = 0; i < padInts[2]; ++i) { - tensorsCenter.push_back(topHcenterSlice); - } - } - tensorsCenter.push_back(input); - if (hasBottomPadding) { - Value bottomHcenterSlice = rewriter.create( - loc, input, extractOffsetsBottom, extractShapeTB, allOneStrides); - for (auto i = 0; i < padInts[3]; ++i) { - tensorsCenter.push_back(bottomHcenterSlice); - } - } - centerTile = rewriter.create(loc, 2, tensorsCenter); - tensorsRes.push_back(centerTile); - - if (hasRightPadding) { - Value vCenterRightSlice = rewriter.create( - loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); - Value vRightSlice = vCenterRightSlice; - if (hasTopPadding) { - Value topRightValue = rewriter.create (loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); - - //pad vCenterRightSlice on the top - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - lowPadding[2] = padInts[2]; - vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); - } - if (hasBottomPadding) { - Value bottomRightValue = rewriter.create (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne}); - - // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - highPadding[2] = padInts[3]; - vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, bottomRightValue); - } - for (auto i=0; i(loc, 3, tensorsRight); - tensorsRes.push_back(rightPadTile); - } - Value resTensor = rewriter.create(loc, 3, tensorsRes); - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, resTensor); - return success(); } - }; -} + + auto createSub = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + // Extract left and right pad tiles. + Value zero = getConstant(rewriter, loc, 0, indexType); + Value one = getConstant(rewriter, loc, 1, indexType); + Value hDimSizeMinusOne = createSub(hDimSize, one); + Value vDimSizeMinusOne = createSub(vDimSize, one); + SmallVector allOneStrides(numDims, one); + + SmallVector extractOffsetsLT(numDims, zero); + extractOffsetsLT[hDim] = zero; + extractOffsetsLT[vDim] = zero; + SmallVector extractShapeLR(numDims, one); + extractShapeLR[hDim] = one; + extractShapeLR[vDim] = vDimSize; + + SmallVector extractOffsetsRight(numDims, zero); + extractOffsetsRight[hDim] = hDimSizeMinusOne; + extractOffsetsRight[vDim] = zero; + + SmallVector extractOffsetsBottom(numDims, zero); + extractOffsetsBottom[hDim] = zero; + extractOffsetsBottom[vDim] = vDimSizeMinusOne; + + SmallVector extractShapeTB(numDims, one); + extractShapeTB[hDim] = hDimSize; + extractShapeTB[vDim] = one; + + SmallVector tensorsLeft; + SmallVector tensorsRight; + SmallVector tensorsCenter; + Value centerTile; + SmallVector tensorsRes; + + if (hasLeftPadding) { + Value vCenterLeftSlice = rewriter.create( + loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); + Value vLeftSlice = vCenterLeftSlice; + if (hasTopPadding) { + Value topLeftValue = rewriter.create( + loc, input, ValueRange{zero, zero, zero, zero}); + // pad vCenterLeftSlice on the top + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + lowPadding[2] = padInts[2]; + vLeftSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); + } + if (hasBottomPadding) { + Value bottomLeftValue = rewriter.create( + loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero}); + + // pad vLeftSlice at the bottom + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + highPadding[2] = padInts[3]; + vLeftSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); + } + for (auto i = 0; i < padInts[0]; ++i) { + tensorsLeft.push_back(vLeftSlice); + } + Value leftPadTile = + rewriter.create(loc, 3, tensorsLeft); + tensorsRes.push_back(leftPadTile); + } + if (hasTopPadding) { + Value topHcenterSlice = rewriter.create( + loc, input, extractOffsetsLT, extractShapeTB, allOneStrides); + for (auto i = 0; i < padInts[2]; ++i) { + tensorsCenter.push_back(topHcenterSlice); + } + } + tensorsCenter.push_back(input); + if (hasBottomPadding) { + Value bottomHcenterSlice = rewriter.create( + loc, input, extractOffsetsBottom, extractShapeTB, allOneStrides); + for (auto i = 0; i < padInts[3]; ++i) { + tensorsCenter.push_back(bottomHcenterSlice); + } + } + centerTile = rewriter.create(loc, 2, tensorsCenter); + tensorsRes.push_back(centerTile); + + if (hasRightPadding) { + Value vCenterRightSlice = rewriter.create( + loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); + Value vRightSlice = vCenterRightSlice; + if (hasTopPadding) { + Value topRightValue = rewriter.create( + loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); + + // pad vCenterRightSlice on the top + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + lowPadding[2] = padInts[2]; + vRightSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); + } + if (hasBottomPadding) { + Value bottomRightValue = rewriter.create( + loc, input, + ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne}); + + // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + highPadding[2] = padInts[3]; + vRightSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vRightSlice, lowPadding, highPadding, + bottomRightValue); + } + for (auto i = 0; i < padInts[1]; ++i) { + tensorsRight.push_back(vRightSlice); + } + Value rightPadTile = + rewriter.create(loc, 3, tensorsRight); + tensorsRes.push_back(rightPadTile); + } + Value resTensor = rewriter.create(loc, 3, tensorsRes); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, resTensor); + return success(); + } +}; +} // namespace namespace { // Converts constant tensor allocation like ops. @@ -348,8 +363,8 @@ public: // Create an uninitialized tensor of `resultSize` shape and fill it with // value `fillVal`. Value constVal = getConstant(rewriter, loc, fillVal, resultElementType); - Value outputTensor = - createInitTensor(rewriter, loc, resultSizeIndex, resultElementType, constVal); + Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex, + resultElementType, constVal); rewriter.replaceOpWithNewOp(op, resultType, outputTensor); return success(); } @@ -384,7 +399,8 @@ public: // Only `none`, `contiguous` and `preserve` memory_format is supported. if (!op.getMemoryFormat().getType().isa()) { int64_t memoryFormat; - if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) + if (!matchPattern(op.getMemoryFormat(), + m_TorchConstantInt(&memoryFormat))) return rewriter.notifyMatchFailure( op, "unimplemented: the memory format should be specified in " "an integer constant"); @@ -495,7 +511,8 @@ public: typeConverter->convertType(op->getResult(0).getType()) .cast(); Type dtype = resultType.getElementType(); - Value start = convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype); + Value start = + convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype); Value end = convertScalarToDtype(rewriter, loc, adaptor.getEnd(), dtype); Value step = convertScalarToDtype(rewriter, loc, adaptor.getStep(), dtype); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 9ff4c6374..543179793 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -426,10 +426,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (isa(op)) return b.create(loc, payloadArgs[0]); - if (isa(op)){ + if (isa(op)) { Value abs = b.create(loc, payloadArgs[0]); Value infinity = b.create( - loc, b.getFloatAttr(abs.getType(), std::numeric_limits::infinity())); + loc, + b.getFloatAttr(abs.getType(), std::numeric_limits::infinity())); return createEqual(b, loc, abs.getType(), abs, infinity); } if (isa(op)) { diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 8bff5034c..0d62010d7 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -7,13 +7,13 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index f0dc4aaf2..00c9fcd7b 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -923,8 +923,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op.getA().getType().template cast().getDtype(); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); - auto result = - rewriter.create(loc, adaptor.getA()); + auto result = rewriter.create(loc, adaptor.getA()); rewriter.replaceOp( op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype)); @@ -1797,8 +1796,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( #define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, \ - context) + patterns.add>(typeConverter, context) INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp); INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp); diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 9c8123bfd..d2b0450cd 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -30,8 +30,8 @@ using namespace mlir::torch::torch_to_stablehlo; namespace { static Value createInitialValueForGatherScatterOp(Operation *op, - RankedTensorType constType, - PatternRewriter &rewriter) { + RankedTensorType constType, + PatternRewriter &rewriter) { auto elementTy = constType.getElementType(); if (isa(op)) { if (elementTy.isa()) { diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index e90f231c7..7ef69ae67 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -35,7 +35,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); // Avg pooling - if (isa(op)) { + if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( @@ -373,7 +374,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } - namespace { template class ConvertAtenAvgPoolOp : public ConvertAtenOp { @@ -388,45 +388,45 @@ public: Type inputElemTy = inputTy.getElementType(); int64_t inputRank = inputTy.getRank(); RankedTensorType outTy = ConvertAtenOp::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + ->convertType(op.getType()) + .template cast(); auto outShape = outTy.getShape(); - if (inputRank <= Dim) { - return op.emitError( - "avg_pooling1d/2d only supports inputs with rank higher than 1/2"); + return op.emitError( + "avg_pooling1d/2d only supports inputs with rank higher than 1/2"); } SmallVector padding, kernelSize, stride; bool ceilMode = false; bool countIncludePad = true; if (!(matchPattern(op.getKernelSize(), - m_TorchListOfConstantInts(kernelSize)))) { - return rewriter.notifyMatchFailure( - op, "non-const int kernel size unsupported!"); + m_TorchListOfConstantInts(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); } if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { - return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); + return rewriter.notifyMatchFailure(op, + "non-const int stride unsupported!"); } if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { - return rewriter.notifyMatchFailure(op, - "non-const int padding unsupported!"); + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); } if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { - return rewriter.notifyMatchFailure(op, - "non-const bool ceil_mode unsupported!"); + return rewriter.notifyMatchFailure( + op, "non-const bool ceil_mode unsupported!"); } if (!(matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad)))) { - return rewriter.notifyMatchFailure( - op, "non-const bool count_include_pad unsupported!"); + m_TorchConstantBool(&countIncludePad)))) { + return rewriter.notifyMatchFailure( + op, "non-const bool count_include_pad unsupported!"); } if constexpr (std::is_same()) { - if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) - return rewriter.notifyMatchFailure( - op, "only None divisor_override supported for now!"); + if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) + return rewriter.notifyMatchFailure( + op, "only None divisor_override supported for now!"); } // Prepend 1 to kernelSize, stride, dilation until they are of same rank @@ -437,33 +437,35 @@ public: SmallVector stablehloPadding(inputRank * 2, 0); std::copy(stride.begin(), stride.end(), - stablehloStride.begin() + inputRank - Dim); + stablehloStride.begin() + inputRank - Dim); std::copy(kernelSize.begin(), kernelSize.end(), - stablehloKernelSize.begin() + inputRank - Dim); + stablehloKernelSize.begin() + inputRank - Dim); if (Dim == 1) { - stablehloPadding[stablehloPadding.size() - 2] = padding[0]; - stablehloPadding[stablehloPadding.size() - 1] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[0]; + stablehloPadding[stablehloPadding.size() - 1] = padding[0]; } else { - stablehloPadding[stablehloPadding.size() - 4] = padding[0]; - stablehloPadding[stablehloPadding.size() - 3] = padding[0]; - stablehloPadding[stablehloPadding.size() - 2] = padding[1]; - stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + stablehloPadding[stablehloPadding.size() - 4] = padding[0]; + stablehloPadding[stablehloPadding.size() - 3] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[1]; + stablehloPadding[stablehloPadding.size() - 1] = padding[1]; } - Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + Value initVal = + createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), + RankedTensorType::get( + {static_cast(stablehloKernelSize.size())}, + rewriter.getI64Type()), stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), + rewriter.getI64Type()), stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), + rewriter.getI64Type()), stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( @@ -485,31 +487,31 @@ public: auto secondArg = *sumBlock.args_rbegin(); { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&sumBlock); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sumBlock); - Value sumResult = - rewriter.create(op->getLoc(), firstArg, secondArg); - rewriter.create(op->getLoc(), sumResult); + Value sumResult = + rewriter.create(op->getLoc(), firstArg, secondArg); + rewriter.create(op->getLoc(), sumResult); } // Use kernel size as the divisor if (countIncludePad) { - Value divisor; - if (Dim == 1) { - divisor = - hlo::getConstTensor(rewriter, op, {kernelSize[0]}, {}) - .value(); - } else { - divisor = hlo::getConstTensor( - rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) - .value(); - } - divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); - DenseIntElementsAttr bcastDimensions; - rewriter.replaceOpWithNewOp( - op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); - return success(); + Value divisor; + if (Dim == 1) { + divisor = + hlo::getConstTensor(rewriter, op, {kernelSize[0]}, {}) + .value(); + } else { + divisor = hlo::getConstTensor( + rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) + .value(); + } + divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); + DenseIntElementsAttr bcastDimensions; + rewriter.replaceOpWithNewOp( + op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); + return success(); } // Use another mhlo.ReduceWindowOp to get the divisor @@ -518,8 +520,8 @@ public: windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy); const auto &options = ConvertAtenOp::getOptions(); - auto inputShapeVec = - *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input, + options.dimSizeIndexBits); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); @@ -544,23 +546,20 @@ public: secondArg = *sizeBlock.args_rbegin(); { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&sizeBlock); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sizeBlock); - Value sumResult = - rewriter.create(op->getLoc(), firstArg, secondArg); - rewriter.create(op->getLoc(), sumResult); + Value sumResult = + rewriter.create(op->getLoc(), firstArg, secondArg); + rewriter.create(op->getLoc(), sumResult); } rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); return success(); - } - }; -} - +} // namespace // AtenCumsumOp template <> @@ -660,10 +659,10 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( context, options); target.addIllegalOp(); patterns.add>(typeConverter, context, options); -#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ +#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ target.addIllegalOp(); \ - patterns.add>( \ - typeConverter, context, options) + patterns.add>(typeConverter, context, \ + options) INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1); INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2); #undef INSERT_ATEN_AVGPOOL_PATTERN diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 36f4d49e9..f495aa395 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -16,13 +16,13 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" using namespace mlir; using namespace mlir::torch; diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index ea19092e6..507821dee 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -22,7 +23,6 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include using namespace mlir; @@ -403,7 +403,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); - int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); + int64_t inputRank = + adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank + 1); if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index acaa60ffc..5ed681c6e 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -131,10 +131,10 @@ tosa::DivOp createBinaryOpAndCast(PatternRewriter &rewriter, } std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, - Operation *op, - Value paramsValue, - Value indexValue, - int32_t axis) { + Operation *op, + Value paramsValue, + Value indexValue, + int32_t axis) { // For easy understanding of this algorithm, the following comments are with // an exact example: torch.aten.gather(!torch.vtensor<[1,4,3],f32>, axis=2, // !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> @@ -210,9 +210,9 @@ std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, // Lowers Gather operators to a sequence of TOSA ops. // taken from // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc -std::optional convertGatherNdOp(PatternRewriter &rewriter, - Operation *op, Type outType, - Value paramsValue, Value indicesValue) { +std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, + Type outType, Value paramsValue, + Value indicesValue) { auto resultType = outType.dyn_cast(); auto paramsType = paramsValue.getType().dyn_cast(); auto indicesType = indicesValue.getType().dyn_cast(); @@ -683,7 +683,6 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, .getResult(); } - // Common function for lowering reduce operations to TOSA ops. template std::optional convertReduceOpCommon( @@ -721,9 +720,8 @@ std::optional convertReduceOpCommon( auto axis_attr = rewriter.getI32IntegerAttr(axis_val); shape_vec[axis_val] = 1; - RankedTensorType reduce_type = RankedTensorType::get( - shape_vec, - reduce_element_type); + RankedTensorType reduce_type = + RankedTensorType::get(shape_vec, reduce_element_type); auto reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, val, axis_attr); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index b8f719792..781a5912d 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -176,7 +176,8 @@ std::optional getZerosLikeTensor(PatternRewriter &rewriter, // Default template creates a constant tensor in T. template std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, - ArrayRef vec, ArrayRef shape, std::optional dtype) { + ArrayRef vec, ArrayRef shape, + std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -188,7 +189,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, } auto width = sizeof(T) * 8; - if constexpr(std::is_same_v) + if constexpr (std::is_same_v) width = 1; auto const_type = @@ -199,7 +200,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } return const_op.getResult(); @@ -209,7 +210,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape, std::optional dtype) { + ArrayRef shape, + std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -228,7 +230,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } return const_op.getResult(); @@ -238,7 +240,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape, std::optional dtype) { + ArrayRef shape, + std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -256,7 +259,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } return const_op.getResult(); @@ -347,23 +350,17 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { } // Template instantiation -template std::optional getConstTensor(PatternRewriter &, - Operation *, - ArrayRef vec, - ArrayRef shape, - std::optional dtype); +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); -template std::optional getConstTensor(PatternRewriter &, - Operation *, - ArrayRef vec, - ArrayRef shape, - std::optional dtype); +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); -template std::optional getConstTensor(PatternRewriter &, - Operation *, - ArrayRef vec, - ArrayRef shape, - std::optional dtype); +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, TypeAttr &accType) { diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index 64352ad1d..1e8c91e8a 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -87,7 +87,8 @@ static TMTensorOp createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter, ValueRange outputs) { SmallVector newOperands = inputs; newOperands.append(outputs.begin(), outputs.end()); - return cast(tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands)); + return cast( + tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands)); } /// Generic conversion pattern that matches any TMTensorOp. This avoids template diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index 5c90df8e6..e7fcbb434 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -157,7 +157,7 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder, return builder.create(loc, intValue); } } - + if (type.isa()) { return builder.create(loc, value.cast()); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e63a4e376..4af9bcfc1 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -203,8 +203,8 @@ static Value getScalarFloatValue(Value input, Location loc, //===----------------------------------------------------------------------===// LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto func = - symbolTable.lookupNearestSymbolFrom(*this, getFunctionAttr()); + auto func = symbolTable.lookupNearestSymbolFrom( + *this, getFunctionAttr()); if (!func) return emitError() << "'@" << getFunction() << "' does not reference a valid function"; @@ -453,11 +453,13 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // If the condition is constant, delete the dead branch and inline the live // branch. patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { - auto constantBool = op.getCondition().getDefiningOp(); + auto constantBool = + op.getCondition().getDefiningOp(); if (!constantBool) return rewriter.notifyMatchFailure(op, "non-constant condition"); - replaceOpWithRegion( - rewriter, op, constantBool.getValue() ? op.getThenRegion() : op.getElseRegion()); + replaceOpWithRegion(rewriter, op, + constantBool.getValue() ? op.getThenRegion() + : op.getElseRegion()); return success(); }); // If the thenRegion and elseRegion yield the same Value's, then use those @@ -515,14 +517,16 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, continue; newResultTypes.push_back(op->getResult(i).getType()); } - auto newIf = - rewriter.create(op->getLoc(), newResultTypes, op.getCondition()); + auto newIf = rewriter.create(op->getLoc(), newResultTypes, + op.getCondition()); rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), newIf.getThenRegion().end()); rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), newIf.getElseRegion().end()); - newIf.getThenRegion().front().getTerminator()->eraseOperands(resultsToErase); - newIf.getElseRegion().front().getTerminator()->eraseOperands(resultsToErase); + newIf.getThenRegion().front().getTerminator()->eraseOperands( + resultsToErase); + newIf.getElseRegion().front().getTerminator()->eraseOperands( + resultsToErase); SmallVector replacementValues; for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) { if (resultsToErase[i]) @@ -548,8 +552,8 @@ void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns, return failure(); if (value) { - rewriter.eraseOp(op); - return success(); + rewriter.eraseOp(op); + return success(); } // Even if the condition is statically false, the assert might never be // executed. @@ -898,10 +902,10 @@ void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, auto rhs = op.getOther(); auto getRhsDevice = rewriter.create(op.getLoc(), rhs); auto getRhsDtype = rewriter.create(op.getLoc(), rhs); - rewriter.replaceOpWithNewOp( - op, op.getType(), lhs, getRhsDevice.getResult(), - getRhsDtype.getResult(), op.getNonBlocking(), - op.getCopy(), op.getMemoryFormat()); + rewriter.replaceOpWithNewOp( + op, op.getType(), lhs, getRhsDevice.getResult(), + getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(), + op.getMemoryFormat()); return success(); }); } @@ -996,7 +1000,7 @@ void AtenMaxOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // `aten.max.other` -> `aten.maximum` patterns.add(+[](AtenMaxOtherOp op, PatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), - op.getOther()); + op.getOther()); return success(); }); } @@ -1934,7 +1938,7 @@ void Torch::ConstantFloatOp::getAsmResultNames( // float string representation). SmallVector buf; getValue().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, - /*TruncateZero=*/false); + /*TruncateZero=*/false); auto isValidMLIRIdentifierChar = [](char c) { return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' || c == '-'; @@ -2045,7 +2049,8 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( // compiler treat the size as having value semantics? // There's a small number of such ops, and they are marked as `inplace_view` // in PyTorch's `native_functions.yaml` file. - rewriter.replaceOpWithNewOp(op, sizeOp.getSelf(), op.getIdx()); + rewriter.replaceOpWithNewOp(op, sizeOp.getSelf(), + op.getIdx()); return success(); }); } @@ -2073,11 +2078,13 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) { void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) { - auto lhsListConstruct = op.getA().getDefiningOp(); + auto lhsListConstruct = + op.getA().getDefiningOp(); if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct)) return failure(); - auto rhsListConstruct = op.getB().getDefiningOp(); + auto rhsListConstruct = + op.getB().getDefiningOp(); if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct)) return failure(); @@ -2195,7 +2202,8 @@ LogicalResult PrimTupleConstructOp::verify() { void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) { - auto tupleConstruct = op.getTup().getDefiningOp(); + auto tupleConstruct = + op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); @@ -2245,7 +2253,8 @@ void PrimUninitializedOp::getCanonicalizationPatterns( void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) { - auto tupleConstruct = op.getTup().getDefiningOp(); + auto tupleConstruct = + op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); @@ -2400,9 +2409,7 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef operands, // AtenAliasOp //===----------------------------------------------------------------------===// -OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { - return getOperand(); -} +OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); } //===----------------------------------------------------------------------===// // AtenFloordivIntOp @@ -2481,14 +2488,12 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { - int64_t start, end, step; - if (matchPattern(getStart(), m_TorchConstantInt(&start)) && - matchPattern(getEnd(), m_TorchConstantInt(&end)) && - matchPattern(getStep(), m_TorchConstantInt(&step)) - && step == 1 - && start == 0 - && end == std::numeric_limits::max()) - return getOperand(0); + int64_t start, end, step; + if (matchPattern(getStart(), m_TorchConstantInt(&start)) && + matchPattern(getEnd(), m_TorchConstantInt(&end)) && + matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 && + start == 0 && end == std::numeric_limits::max()) + return getOperand(0); auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); @@ -2744,7 +2749,7 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { // aten.Int.Tensor, fold to the scalar number. if (auto numToTensorScalar = getA().getDefiningOp()) return numToTensorScalar.getA(); - if (auto tensorIntOp = getA().getDefiningOp()) + if (auto tensorIntOp = getA().getDefiningOp()) return tensorIntOp.getT(); return nullptr; } @@ -2955,7 +2960,6 @@ LogicalResult AtenPermuteOp::verify() { << " elements, the output has rank " << outRank << '.'; } - // Initialization of the reverse permutation. -1 denotes an unknown // permutation index. SmallVector reversePermutation(outRank, -1); diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index a154fb465..7e3f37a7b 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -440,7 +440,7 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { } else if (auto integerType = dtype.dyn_cast()) { return IntegerType::get(context, integerType.getWidth(), IntegerType::Signless); - } else if (dtype.isa()){ + } else if (dtype.isa()) { return dtype; } @@ -556,9 +556,9 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { // TODO: These are not DRY in that the two type predicates AnyTorchDictKeyType // and AnyTorchType generate the exact same code (in TorchOps.cpp.inc). -// Unfortunately the generated implementations aren't visible/exposed ("static" linkage) -// and the predicates themselves can't be added/used in the specification of the parameters -// of the Torch_DictType. +// Unfortunately the generated implementations aren't visible/exposed ("static" +// linkage) and the predicates themselves can't be added/used in the +// specification of the parameters of the Torch_DictType. static bool isAnyTorchDictKeyType(Type type) { return type.isa() || type.isa() || type.isa() || type.isa() || diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index f4e8a60ec..d1794de93 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -355,7 +355,7 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, auto rhsType = rhs.getType().cast(); Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype() - : rhsType.getOptionalDtype(); + : rhsType.getOptionalDtype(); llvm::SmallDenseMap lhsDimShapeMap; for (size_t idx = 0; idx < lhsTokens.size(); ++idx) { @@ -457,7 +457,6 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, return success(); } - static Value performLastReduceAndPermute(PatternRewriter &rewriter, Location loc, Type outType, Value input, @@ -1269,7 +1268,8 @@ public: }; } // namespace -// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into `AtenMinDimOp` +// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into +// `AtenMinDimOp` namespace { template class DecomposeAtenArgMinMaxOp : public OpRewritePattern { @@ -1300,9 +1300,9 @@ public: .cast(); // If the dim type is `NoneType` i.e. reduce along all the dimensions. - // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so first the input - // tensor is flattened to 1d tensor and then the reduction happens on the - // 0th dimension. + // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so + // first the input tensor is flattened to 1d tensor and then the reduction + // happens on the 0th dimension. if (dim.getType().isa()) { BaseTensorType flattenType = inputType @@ -1317,11 +1317,11 @@ public: } Value resultArg = - rewriter - .create(loc, valueTensorType, indicesTensorType, - input, dim, keepDim) - .getIndices(); - + rewriter + .create(loc, valueTensorType, indicesTensorType, input, + dim, keepDim) + .getIndices(); + rewriter.replaceOp(op, resultArg); return success(); } @@ -1959,10 +1959,12 @@ public: // Define λ and α double scale = 1.0507009873554804934193349852946; double alpha = 1.6732632423543772848170429916717; - + // Create constants for λ and α - Value scaleVal = rewriter.create(loc, rewriter.getF64FloatAttr(scale)); - Value alphaVal = rewriter.create(loc, rewriter.getF64FloatAttr(alpha)); + Value scaleVal = rewriter.create( + loc, rewriter.getF64FloatAttr(scale)); + Value alphaVal = rewriter.create( + loc, rewriter.getF64FloatAttr(alpha)); // Create zero tensor for comparison Value constantZero = @@ -1972,17 +1974,21 @@ public: // Calculate positive and negative parts Value constantOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value positiveOutput = rewriter.create(loc, resType, zeroTensor, input); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, input); Value minZeroX = rewriter.create(loc, resType, zeroTensor, input); Value expInput = rewriter.create(loc, resType, minZeroX); - Value expInputMinusOne = rewriter.create(loc, resType, expInput, constantOne, constantOne); - Value negativeOutput = rewriter.create(loc, resType, expInputMinusOne, alphaVal); + Value expInputMinusOne = rewriter.create( + loc, resType, expInput, constantOne, constantOne); + Value negativeOutput = rewriter.create( + loc, resType, expInputMinusOne, alphaVal); // Multiply the result by λ Value seluOutput = rewriter.create( loc, resType, positiveOutput, negativeOutput, constantOne); - seluOutput = rewriter.create(loc, resType, seluOutput, scaleVal); + seluOutput = + rewriter.create(loc, resType, seluOutput, scaleVal); // Replace the original operation rewriter.replaceOp(op, seluOutput); @@ -2592,79 +2598,89 @@ public: namespace { - static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter, - Location loc, Value input, - int64_t dimA, int64_t dimB, - Value &transposed) { - Type transposedType; - if (failed(getTransposedType(input.getType().cast(), - dimA, dimB, transposedType))) - return failure(); - Value cstDimA = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimA)); - Value cstDimB = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimB)); - transposed = rewriter.create( - loc, transposedType, input, cstDimA, cstDimB); - return success(); - } - - class DecomposeAtenConvTbcOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenConvTbcOp op, - PatternRewriter &rewriter) const override { - Value emptyList = rewriter.create( - op.getLoc(), - Torch::ListType::get(Torch::IntType::get(op.getContext())), - SmallVector()); - Value cstFalse = rewriter.create(op.getLoc(), false); - Value oneList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), - SmallVector{rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(1))}); - Value padding = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), - SmallVector{op.getPad()}); - Value groups = rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(1)); - - // convtbc has WNC layout for input and output - // and WCF layout for weight - // whereas Convolution is going to use Conv1DNcwFcwOp for 1d - // which means we need the inputs in NCW and the weight in FCW - Value selfWnc = op.getSelf(); - Value selfNwc; - Value selfNcw; - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc, 0, 1, selfNwc))) - return rewriter.notifyMatchFailure(op, "failed to transpose input to Nwc"); - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc, 1, 2, selfNcw))) - return rewriter.notifyMatchFailure(op, "failed to transpose input to Ncw"); - - Value weightWcf = op.getWeight(); - Value weightFcw; - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), weightWcf, 0, 2, weightFcw))) - return rewriter.notifyMatchFailure(op, "failed to transpose weight to Fcw"); - - - Value outputNcw = rewriter.create( - op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(), /*stride*/oneList, - /*padding*/ padding, /*dilation*/ oneList, - /*transpose*/ cstFalse, /*output_padding*/ emptyList, - groups); - - // convert output from Ncw to Wnc - Value outputNwc; - Value outputWnc; - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNcw, 1, 2, outputNwc))) - return rewriter.notifyMatchFailure(op, "failed to transpose output to Nwc"); - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNwc, 0, 1, outputWnc))) - return rewriter.notifyMatchFailure(op, "failed to transpose output to Wnc"); - rewriter.replaceOp(op, outputWnc); - - return success(); - } - }; +static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter, + Location loc, Value input, + int64_t dimA, + int64_t dimB, + Value &transposed) { + Type transposedType; + if (failed(getTransposedType(input.getType().cast(), + dimA, dimB, transposedType))) + return failure(); + Value cstDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create( + loc, transposedType, input, cstDimA, cstDimB); + return success(); } +class DecomposeAtenConvTbcOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvTbcOp op, + PatternRewriter &rewriter) const override { + Value emptyList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector()); + Value cstFalse = rewriter.create(op.getLoc(), false); + Value oneList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(1))}); + Value padding = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{op.getPad()}); + Value groups = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(1)); + + // convtbc has WNC layout for input and output + // and WCF layout for weight + // whereas Convolution is going to use Conv1DNcwFcwOp for 1d + // which means we need the inputs in NCW and the weight in FCW + Value selfWnc = op.getSelf(); + Value selfNwc; + Value selfNcw; + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc, + 0, 1, selfNwc))) + return rewriter.notifyMatchFailure(op, + "failed to transpose input to Nwc"); + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc, + 1, 2, selfNcw))) + return rewriter.notifyMatchFailure(op, + "failed to transpose input to Ncw"); + + Value weightWcf = op.getWeight(); + Value weightFcw; + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), + weightWcf, 0, 2, weightFcw))) + return rewriter.notifyMatchFailure(op, + "failed to transpose weight to Fcw"); + + Value outputNcw = rewriter.create( + op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(), + /*stride*/ oneList, + /*padding*/ padding, /*dilation*/ oneList, + /*transpose*/ cstFalse, /*output_padding*/ emptyList, groups); + + // convert output from Ncw to Wnc + Value outputNwc; + Value outputWnc; + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), + outputNcw, 1, 2, outputNwc))) + return rewriter.notifyMatchFailure(op, + "failed to transpose output to Nwc"); + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), + outputNwc, 0, 1, outputWnc))) + return rewriter.notifyMatchFailure(op, + "failed to transpose output to Wnc"); + rewriter.replaceOp(op, outputWnc); + + return success(); + } +}; +} // namespace // Decompose aten.conv1d to aten.convolution namespace { @@ -3815,8 +3831,8 @@ public: /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); Value stdRandN = rewriter.create(loc, resultType, randN, std); - rewriter.replaceOpWithNewOp(op, resultType, stdRandN, - mean, /*alpha=*/one); + rewriter.replaceOpWithNewOp(op, resultType, stdRandN, mean, + /*alpha=*/one); return success(); } }; @@ -6654,8 +6670,10 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal>(patterns); - addPatternIfTargetOpIsIllegal>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenArgMinMaxOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenArgMinMaxOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6768,8 +6786,6 @@ public: addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - - GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index da8be9b17..239960629 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -170,8 +170,8 @@ private: auto attr = std::get<1>(t); nameStack.push_back(attr.getName().str()); if (attr.getType().isa()) { - if (failed( - recursivelyTraverse(slot.getValue().getDefiningOp()))) + if (failed(recursivelyTraverse( + slot.getValue().getDefiningOp()))) return failure(); } else if (usedSlots.find(slot) != usedSlots.end()) { // Only create the GlobalSlotOp if the slot is used at all. @@ -190,8 +190,8 @@ private: } for (auto method : classType.getOps()) { nameStack.push_back(method.getName().str()); - funcLinkageInfo[{nnModule, - symbolTable.lookup(method.getFunction())}] = + funcLinkageInfo[{ + nnModule, symbolTable.lookup(method.getFunction())}] = LinkageInfo{llvm::join(nameStack, "."), method.getIsPrivate()}; nameStack.pop_back(); } @@ -501,21 +501,24 @@ static LogicalResult rewriteMonomorphizedFuncClone( SmallVector toErase; auto handlePrimSetAttr = [&](PrimSetAttrOp op) { - auto instance = mapping.lookup(op.getReceiver()).getDefiningOp(); + auto instance = + mapping.lookup(op.getReceiver()).getDefiningOp(); SlotOp affectedSlot; for (auto slot : instance.getOps()) { if (slot.getName() == op.getName()) affectedSlot = slot; } OpBuilder(op).create( - op.getLoc(), objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(), + op.getLoc(), + objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(), op.getValue()); toErase.push_back(op); return WalkResult::advance(); }; auto handlePrimGetAttr = [&](PrimGetAttrOp op) { if (!op.getType().isa()) { - auto instance = mapping.lookup(op.getReceiver()).getDefiningOp(); + auto instance = + mapping.lookup(op.getReceiver()).getDefiningOp(); SlotOp affectedSlot; for (auto slot : instance.getOps()) { if (slot.getName() == op.getName()) diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index c67e6dc0d..1e8c90dea 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -163,7 +163,8 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { } if (auto globalSlotSet = dyn_cast(op)) { auto *state = getOrCreate( - getProgramPoint(globalSlotSet.getSlotAttr())); + getProgramPoint( + globalSlotSet.getSlotAttr())); propagateIfChanged(state, state->setSafe(false)); } // Save the InitializeGlobalSlotsOp for later referencee @@ -211,8 +212,8 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { auto it = llvm::find(initializeGlobalSlotsOp.getSlotSymNames(), static_cast(flatSymbolRefPoint->getValue())); - Value value = initializeGlobalSlotsOp->getOperand( - std::distance(initializeGlobalSlotsOp.getSlotSymNames().begin(), it)); + Value value = initializeGlobalSlotsOp->getOperand(std::distance( + initializeGlobalSlotsOp.getSlotSymNames().begin(), it)); auto *flatSymbolRefState = getOrCreateFor(value, flatSymbolRefPoint); @@ -331,7 +332,8 @@ class InlineGlobalSlotsPass DenseSet safeToInline; for (int i = 0, e = initialize->getNumOperands(); i != e; i++) { - auto slotSymName = initialize.getSlotSymNames()[i].cast(); + auto slotSymName = + initialize.getSlotSymNames()[i].cast(); Value operand = initialize.getOperand(i); auto symbolRefPoint = solver.getProgramPoint( initialize.getSlotSymNames()[i].cast()); @@ -405,7 +407,8 @@ class InlineGlobalSlotsPass SmallVector newSlotSymNames; SmallVector newInitialValues; for (int i = 0, e = initialize.getNumOperands(); i != e; i++) { - auto slotSymName = initialize.getSlotSymNames()[i].cast(); + auto slotSymName = + initialize.getSlotSymNames()[i].cast(); if (!safeToInline.count(slotSymName)) { newSlotSymNames.push_back(slotSymName); newInitialValues.push_back(initialize.getOperand(i)); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 34874cb59..befdf808a 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -202,15 +202,16 @@ static bool satisfiesBackendContract(ModuleOp module, // Check for unimplemented operators first to give more direct diagnostics. walkResult0 = module.walk([&](Torch::OperatorOp op) { if (llvm::all_of(op.getResults(), [&op](auto res) { - return succeeded( - checkType(op.getOperation(), res.getType(), /*actuallyEmitDiagnostics=*/false)); + return succeeded(checkType(op.getOperation(), res.getType(), + /*actuallyEmitDiagnostics=*/false)); })) { return WalkResult::advance(); } if (actuallyEmitDiagnostics) { - op->emitError("unsupported by backend contract: Unimplemented operator '" - + op.getName() + "'"); + op->emitError( + "unsupported by backend contract: Unimplemented operator '" + + op.getName() + "'"); } return WalkResult::interrupt(); }); @@ -309,20 +310,22 @@ public: << " iterations of the simplification pipeline\n"; }); } + private: llvm::StringSet<> backendLegalOpsSet; }; class VerifyBackendContractNoDecompositionsPass - : public VerifyBackendContractNoDecompositionsBase { + : public VerifyBackendContractNoDecompositionsBase< + VerifyBackendContractNoDecompositionsPass> { public: VerifyBackendContractNoDecompositionsPass() = default; void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target = - getBackendContractTarget(context, /*decompose*/false, - /*backendLegalOpsSet*/{}); + getBackendContractTarget(context, /*decompose*/ false, + /*backendLegalOpsSet*/ {}); if (!satisfiesBackendContract(getOperation(), target, /*actuallyEmitDiagnostics=*/true)) { diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 7c3ceab3a..a34e0208c 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -158,9 +158,11 @@ void Torch::importLibraryFunctions(ModuleOp module, ModuleOp library, } } -FailureOr Torch::adjustFunctionArg( - OpBuilder &b, Location loc, Value operand, Type desiredType, - function_ref baseTransformation) { +FailureOr +Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, + Type desiredType, + function_ref + baseTransformation) { operand = baseTransformation(b, loc, operand, desiredType); // No need for adjustment if they already match. diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 6860fbb6e..fbbd6c480 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -90,7 +90,8 @@ public: PatternRewriter &rewriter) const override { SmallVector> ranks; SmallVector dtypes; - if (!matchPattern(op.getRanks(), m_TorchListOfOptionalConstantInts(ranks))) { + if (!matchPattern(op.getRanks(), + m_TorchListOfOptionalConstantInts(ranks))) { return rewriter.notifyMatchFailure( op, "Expected `ranks` to be a list of optional constant ints"); } diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index bf371d7c4..e2abee51b 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -344,9 +344,9 @@ FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, // Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If // yes, then computes the final broadcast shape. void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, - Value inputA, Value inputB, - SmallVector &resultShape, - SmallVector &resultShapeValue) { + Value inputA, Value inputB, + SmallVector &resultShape, + SmallVector &resultShapeValue) { SmallVector shapeA{ inputA.getType().cast().getSizes()}; SmallVector shapeB{ @@ -514,7 +514,7 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc, } LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA, - int64_t dimB, Type &transposedType) { + int64_t dimB, Type &transposedType) { if (!inType.hasSizes()) return failure(); SmallVector shape(inType.getSizes()); diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index 4d38f4965..ac9a72586 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -54,14 +54,14 @@ void TorchConversionDialect::initialize() { addInterfaces(); } - //===----------------------------------------------------------------------===// // Constant materializer. //===----------------------------------------------------------------------===// Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder, - Attribute value, Type type, - Location loc) { + Attribute value, + Type type, + Location loc) { if (auto integerType = type.dyn_cast()) return builder.create(loc, value.cast()); diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index 8a5c218e4..1cda55724 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; using namespace mlir::torch; @@ -57,16 +57,16 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::BoolType type) -> std::optional { return IntegerType::get(type.getContext(), 1); }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 1 && type.isSignless())) - return std::nullopt; - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IntegerType type, ValueRange inputs, + Location loc) -> std::optional { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 1 && type.isSignless())) + return std::nullopt; + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -83,19 +83,19 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::IntType type) -> std::optional { return IntegerType::get(type.getContext(), 64); }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!inputs[0].getType().isa()) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IntegerType type, ValueRange inputs, + Location loc) -> std::optional { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return std::nullopt; + // Other input type to be converted to i64 are handled by other + // materializers. + if (!inputs[0].getType().isa()) + return std::nullopt; + assert(inputs.size() == 1); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -112,13 +112,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::FloatType type) -> std::optional { return Float64Type::get(type.getContext()); }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - Float64Type type, ValueRange inputs, - Location loc) -> std::optional { - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, Float64Type type, ValueRange inputs, + Location loc) -> std::optional { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -133,22 +133,23 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, TypeConverter &typeConverter) { target.addLegalOp(); - typeConverter.addConversion([](Torch::GeneratorType type) -> std::optional { - return IntegerType::get(type.getContext(), 64); - }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!inputs[0].getType().isa()) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addConversion( + [](Torch::GeneratorType type) -> std::optional { + return IntegerType::get(type.getContext(), 64); + }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IntegerType type, ValueRange inputs, + Location loc) -> std::optional { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return std::nullopt; + // Other input type to be converted to i64 are handled by other + // materializers. + if (!inputs[0].getType().isa()) + return std::nullopt; + assert(inputs.size() == 1); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 175a3cd14..514d05234 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -18,8 +18,8 @@ #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; @@ -65,7 +65,8 @@ public: auto getConstantIntegerFromDefiningOp = [](Value operand, int &extractedInt) { - auto castOp = dyn_cast(operand.getDefiningOp()); + auto castOp = + dyn_cast(operand.getDefiningOp()); if (!castOp) { return failure(); } @@ -83,7 +84,8 @@ public: return failure(); } int unpackedBitWidth; - if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, unpackedBitWidth))) { + if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, + unpackedBitWidth))) { return failure(); } if (unpackedBitWidth != @@ -103,32 +105,35 @@ public: // expand lhs std::vector lhsExpandedShape = {lhsShape[0], lhsShape[1], lhsReductDimSize / gs, gs}; - RankedTensorType lhsExpandedType = RankedTensorType::get(lhsExpandedShape, elementType); + RankedTensorType lhsExpandedType = + RankedTensorType::get(lhsExpandedShape, elementType); SmallVector lhsReassociation = {{0}, {1}, {2, 3}}; Value lhsExpanded = rewriter.create( - loc, lhsExpandedType, lhs, lhsReassociation); + loc, lhsExpandedType, lhs, lhsReassociation); // expand rhs - std::vector rhsExpandedShape = {rhsShape[0], rhsReductDimSize/gs, gs}; - RankedTensorType rhsExpandedType = RankedTensorType::get(rhsExpandedShape, rhsElementType); + std::vector rhsExpandedShape = {rhsShape[0], rhsReductDimSize / gs, + gs}; + RankedTensorType rhsExpandedType = + RankedTensorType::get(rhsExpandedShape, rhsElementType); SmallVector rhsReassociation = {{0}, {1, 2}}; Value rhsExpanded = rewriter.create( - loc, rhsExpandedType, rhsQuant, rhsReassociation); + loc, rhsExpandedType, rhsQuant, rhsReassociation); Value cst0 = rewriter.create( - loc, FloatAttr::get(elementType, 0.0)); + loc, FloatAttr::get(elementType, 0.0)); - Value emptyDequant = rewriter.create( - loc, rhsExpandedShape, elementType); + Value emptyDequant = + rewriter.create(loc, rhsExpandedShape, elementType); SmallVector dynDims; for (int i = 0; i < lhsType.getRank(); i++) { if (lhsType.isDynamicDim(i)) { dynDims.push_back(rewriter.create(loc, lhs, i)); } } - Value empty = rewriter.create( - loc, resultShape, elementType, dynDims); - Value output = rewriter.create( - loc, cst0, empty).getResult(0); + Value empty = rewriter.create(loc, resultShape, + elementType, dynDims); + Value output = + rewriter.create(loc, cst0, empty).getResult(0); AffineExpr d0, d1, d2, d3, d4; bindDims(getContext(), d0, d1, d2, d3, d4); @@ -141,12 +146,12 @@ public: SmallVector dqIndexingMaps = {map, map1, map1, map}; SmallVector matIndexingMaps = {map2, map3, map4}; - SmallVector dequantIteratorTypes(3, utils::IteratorType::parallel); + SmallVector dequantIteratorTypes( + 3, utils::IteratorType::parallel); SmallVector matmulIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::parallel, utils::IteratorType::reduction, - utils::IteratorType::reduction - }; + utils::IteratorType::parallel, utils::IteratorType::parallel, + utils::IteratorType::parallel, utils::IteratorType::reduction, + utils::IteratorType::reduction}; Value rhsDequant = rewriter @@ -157,9 +162,12 @@ public: /*iteratorTypes=*/dequantIteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value w = args[0], scale = args[1], zeroPoint = args[2]; - Value extw = b.create(loc, rewriter.getI32Type(), w); - Value fp_extw = b.create(loc, rewriter.getF16Type(), extw); - Value shifted = b.create(loc, fp_extw, zeroPoint); + Value extw = + b.create(loc, rewriter.getI32Type(), w); + Value fp_extw = b.create( + loc, rewriter.getF16Type(), extw); + Value shifted = + b.create(loc, fp_extw, zeroPoint); Value dqw = b.create(loc, shifted, scale); b.create(loc, dqw); }) @@ -168,8 +176,8 @@ public: Value matmulDequant = rewriter .create( - loc, output.getType(), - ValueRange{lhsExpanded, rhsDequant}, output, + loc, output.getType(), ValueRange{lhsExpanded, rhsDequant}, + output, /*indexingMaps=*/matIndexingMaps, /*iteratorTypes=*/matmulIteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -188,7 +196,8 @@ public: namespace { class ConvertCustomQuantOpPass - : public TorchConversion::ConvertCustomQuantOpBase { + : public TorchConversion::ConvertCustomQuantOpBase< + ConvertCustomQuantOpPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -213,8 +222,8 @@ class ConvertCustomQuantOpPass target.addIllegalOp(); patterns.add(typeConverter, context); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index 93d7de825..5ad3fa1c9 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -33,7 +33,6 @@ using namespace mlir::torch; using namespace mlir::torch::TorchConversion; using namespace TMTensor; - namespace { class VerifyLinalgOnTensorsBackendContractPass : public VerifyLinalgOnTensorsBackendContractBase< @@ -96,7 +95,8 @@ class VerifyLinalgOnTensorsBackendContractPass // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics // doesn't unnecessarily spew out the entire module. emitError(module.getLoc()) - << "Module does not conform to the linalg-on-tensors backend contract. " + << "Module does not conform to the linalg-on-tensors backend " + "contract. " "See dialect conversion legality information above."; return signalPassFailure(); } diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index 888f29ade..c6085f419 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -45,7 +45,8 @@ class VerifyStablehloBackendContractPass ConversionTarget target(*context); // Structural operations. - target.addDynamicallyLegalOp(opHasLegalTypes); + target.addDynamicallyLegalOp( + opHasLegalTypes); // Shape operations. target.addDynamicallyLegalOp(opHasLegalTypes); diff --git a/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp b/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp index bd4fe52b7..dc0448796 100644 --- a/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp +++ b/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp @@ -31,18 +31,18 @@ TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape) PRINT_FUNCTION(); } TorchMlirBackendData::TorchMlirBackendData( - BackendDevice device, Shape shape, std::shared_ptr info) + BackendDevice device, Shape shape, std::shared_ptr info) : BackendData(device, shape), info_(info) { PRINT_FUNCTION(); } -TorchMlirBackendData::TorchMlirBackendData( - const at::Scalar& scalar, BackendDevice device) +TorchMlirBackendData::TorchMlirBackendData(const at::Scalar &scalar, + BackendDevice device) : BackendData(device, Shape(scalar.type(), {})), info_(std::make_shared(scalar)) { PRINT_FUNCTION(); } -TorchMlirBackendData::TorchMlirBackendData( - const at::Tensor& tensor, BackendDevice device, Shape shape) +TorchMlirBackendData::TorchMlirBackendData(const at::Tensor &tensor, + BackendDevice device, Shape shape) : BackendData(device, shape), info_(std::make_shared(tensor)) { PRINT_FUNCTION(); @@ -52,19 +52,18 @@ BackendData::Handle TorchMlirBackendData::GetHandle() { return reinterpret_cast(this); } -void TorchMlirBackendData::Assign(const BackendData& data) { - const TorchMlirBackendData* torch_mlir_data = - dynamic_cast(&data); - TORCH_CHECK( - torch_mlir_data, - "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); +void TorchMlirBackendData::Assign(const BackendData &data) { + const TorchMlirBackendData *torch_mlir_data = + dynamic_cast(&data); + TORCH_CHECK(torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); info_ = torch_mlir_data->info_; } bool TorchMlirBackendData::HasValue() const { return bool(info_); } -BackendData::Info* TorchMlirBackendData::mlir_info() const { +BackendData::Info *TorchMlirBackendData::mlir_info() const { return info_.get(); } @@ -77,8 +76,8 @@ void TorchMlirBackendImpl::PrepareToExit() const {} * IR Tracing * */ -const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const { - static const IrBuilder* builder = new TorchMlirIrBuilder(); +const IrBuilder *TorchMlirBackendImpl::GetIrBuilder() const { + static const IrBuilder *builder = new TorchMlirIrBuilder(); return builder; } @@ -87,28 +86,29 @@ const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const { * */ BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor( - const at::Tensor& tensor, const Shape& shape, - const BackendDevice& device) const { + const at::Tensor &tensor, const Shape &shape, + const BackendDevice &device) const { PRINT_FUNCTION(); return std::make_shared(tensor, device, shape); } BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar( - const at::Scalar& scalar, const BackendDevice& device) const { + const at::Scalar &scalar, const BackendDevice &device) const { PRINT_FUNCTION(); return std::make_shared(scalar, device); } -BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder( - const BackendDevice& device, const Shape& shape) const { +BackendDataPtr +TorchMlirBackendImpl::CreateDataPlaceholder(const BackendDevice &device, + const Shape &shape) const { PRINT_FUNCTION(); return std::make_shared(device, shape); } BackendDataPtr -TorchMlirBackendImpl::GetComputationDataFromNode(const Node* node) const { +TorchMlirBackendImpl::GetComputationDataFromNode(const Node *node) const { PRINT_FUNCTION(); - const auto* device_data_node = dynamic_cast(node); + const auto *device_data_node = dynamic_cast(node); if (!device_data_node) { return nullptr; } @@ -120,14 +120,13 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( c10::optional logical_scalar_type) const { PRINT_FUNCTION(); - TorchMlirBackendData* torch_mlir_data = - dynamic_cast(data.get()); - TORCH_CHECK( - torch_mlir_data, - "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); + TorchMlirBackendData *torch_mlir_data = + dynamic_cast(data.get()); + TORCH_CHECK(torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); - TorchMlirBackendData::Info* info = - dynamic_cast(torch_mlir_data->mlir_info()); + TorchMlirBackendData::Info *info = + dynamic_cast(torch_mlir_data->mlir_info()); TORCH_CHECK( info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); @@ -140,17 +139,19 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( * */ std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, Util::EmissionMap emit_status) const { + const std::string &name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) const { PRINT_FUNCTION(); return std::make_unique( name, std::forward(device), - std::forward>(post_order), + std::forward>(post_order), std::forward(emit_status)); } -std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( - const std::string& name, BackendDevice device) const { +std::unique_ptr +TorchMlirBackendImpl::CreateLoweringContext(const std::string &name, + BackendDevice device) const { PRINT_FUNCTION(); return std::make_unique( name, std::forward(device)); @@ -175,9 +176,8 @@ at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const { // Query all available backend devices std::vector TorchMlirBackendImpl::GetBackendDevices() const { PRINT_FUNCTION(); - return { - GetBackendDevice(c10::Device(c10::kLazy, 0)), - GetBackendDevice(c10::Device(c10::kCPU, 0))}; + return {GetBackendDevice(c10::Device(c10::kLazy, 0)), + GetBackendDevice(c10::Device(c10::kCPU, 0))}; } // Map a particular c10:: device to a concrete backend device diff --git a/projects/ltc/csrc/base_lazy_backend/backend_impl.h b/projects/ltc/csrc/base_lazy_backend/backend_impl.h index c77033593..4029cab1e 100644 --- a/projects/ltc/csrc/base_lazy_backend/backend_impl.h +++ b/projects/ltc/csrc/base_lazy_backend/backend_impl.h @@ -41,27 +41,28 @@ public: name = ss.str(); ++i; } - Info(const Info& other) + Info(const Info &other) : tensor{other.tensor}, scalar{other.scalar}, requires_grad{other.requires_grad}, name{other.name} {} - Info(const at::Tensor& tensor) + Info(const at::Tensor &tensor) : tensor{tensor}, requires_grad{tensor.requires_grad()} {} - Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {} + Info(const at::Scalar &scalar) : scalar{scalar}, requires_grad(false) {} }; TorchMlirBackendData(BackendDevice device, Shape shape); - TorchMlirBackendData(BackendDevice device, Shape shape, std::shared_ptr info); - TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device); - TorchMlirBackendData( - const at::Tensor& tensor, BackendDevice device, Shape shape); + TorchMlirBackendData(BackendDevice device, Shape shape, + std::shared_ptr info); + TorchMlirBackendData(const at::Scalar &scalar, BackendDevice device); + TorchMlirBackendData(const at::Tensor &tensor, BackendDevice device, + Shape shape); virtual BackendData::Handle GetHandle() override; - virtual void Assign(const BackendData& data) override; + virtual void Assign(const BackendData &data) override; virtual bool HasValue() const override; - BackendData::Info* mlir_info() const; + BackendData::Info *mlir_info() const; protected: std::shared_ptr info_; @@ -80,7 +81,7 @@ public: * IR Tracing * */ - const IrBuilder* GetIrBuilder() const override; + const IrBuilder *GetIrBuilder() const override; /** * Configuration @@ -91,19 +92,22 @@ public: * Data Transfer * */ - virtual BackendDataPtr MakeComputationDataFromTensor( - const at::Tensor& tensor, const Shape& shape, - const BackendDevice& device) const override; + virtual BackendDataPtr + MakeComputationDataFromTensor(const at::Tensor &tensor, const Shape &shape, + const BackendDevice &device) const override; - virtual BackendDataPtr MakeComputationDataFromScalar( - const at::Scalar& scalar, const BackendDevice& device) const override; + virtual BackendDataPtr + MakeComputationDataFromScalar(const at::Scalar &scalar, + const BackendDevice &device) const override; - virtual BackendDataPtr CreateDataPlaceholder( - const BackendDevice& device, const Shape& shape) const override; + virtual BackendDataPtr + CreateDataPlaceholder(const BackendDevice &device, + const Shape &shape) const override; // Gets backend data if the node is a device data node. Otherwise returns // nullptr. - virtual BackendDataPtr GetComputationDataFromNode(const Node*) const override; + virtual BackendDataPtr + GetComputationDataFromNode(const Node *) const override; virtual at::Tensor MakeTensorFromComputationData( const BackendDataPtr data, @@ -113,13 +117,14 @@ public: * Lowering, Compilation, Execution * */ - virtual std::unique_ptr CreateLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, - Util::EmissionMap emit_status) const override; + virtual std::unique_ptr + CreateLoweringContext(const std::string &name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) const override; - virtual std::unique_ptr CreateLoweringContext( - const std::string& name, BackendDevice device) const override; + virtual std::unique_ptr + CreateLoweringContext(const std::string &name, + BackendDevice device) const override; // TODO(whc) need to keep this? // virtual std::vector GetCompilationDevices( diff --git a/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp index ca6d80f1f..c11c1563b 100644 --- a/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp +++ b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp @@ -16,20 +16,18 @@ namespace torch { namespace lazy { DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed) - : TorchMlirNode( - op, operands, /*num_outputs=*/1, - /* hash_seed */ HashCombine(op.hash(), hash_seed)) {} + : TorchMlirNode(op, operands, /*num_outputs=*/1, + /* hash_seed */ HashCombine(op.hash(), hash_seed)) {} std::string DimensionNode::ToString() const { return "DimensionNode"; } SizeNode::SizeNode(Value input, size_t dim) - : DimensionNode( - OpKind{c10::Symbol::fromQualString("aten::size")}, {input}, - MHash(dim)), + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::size")}, {input}, + MHash(dim)), dim_(dim){}; int64_t SizeNode::getStaticValue() const { - return dynamic_cast(operand(0).node) + return dynamic_cast(operand(0).node) ->shape(0) .size(dim_); } @@ -40,8 +38,9 @@ SizeAdd::SizeAdd(Value a, Value b) : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){}; int64_t SizeAdd::getStaticValue() const { - return dynamic_cast(operand(0).node)->getStaticValue() + - dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node) + ->getStaticValue() + + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeAdd::ToString() const { return "SizeAdd"; } @@ -50,8 +49,9 @@ SizeMul::SizeMul(Value a, Value b) : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){}; int64_t SizeMul::getStaticValue() const { - return dynamic_cast(operand(0).node)->getStaticValue() * - dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node) + ->getStaticValue() * + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeMul::ToString() const { return "SizeMul"; } @@ -61,11 +61,12 @@ SizeDiv::SizeDiv(Value a, Value b) int64_t SizeDiv::getStaticValue() const { TORCH_CHECK( - dynamic_cast(operand(1).node)->getStaticValue() != + dynamic_cast(operand(1).node)->getStaticValue() != 0, "Can't divide a dimension by zero"); - return dynamic_cast(operand(0).node)->getStaticValue() / - dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node) + ->getStaticValue() / + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeDiv::ToString() const { return "SizeDiv"; } diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp index 7e6f40c5c..a27889ad0 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -12,14 +12,14 @@ #include -#include -#include -#include -#include -#include "torch-mlir-c/Registration.h" -#include "torch-mlir-c/Transforms.h" #include "mlir-c/IR.h" #include "mlir-c/Pass.h" +#include "torch-mlir-c/Registration.h" +#include "torch-mlir-c/Transforms.h" +#include +#include +#include +#include #include "backend_impl.h" #include "jit_ir_importer/function_importer.h" @@ -38,8 +38,8 @@ namespace lazy { // TorchMlir Lowering Context /////////////////////////////////////////////////////////////////////////////// -TorchMlirLoweringContext::TorchMlirLoweringContext( - const std::string& name, BackendDevice device) +TorchMlirLoweringContext::TorchMlirLoweringContext(const std::string &name, + BackendDevice device) : LoweringContext(name, std::forward(device)), graph_(std::make_shared()), function_( @@ -49,11 +49,12 @@ TorchMlirLoweringContext::TorchMlirLoweringContext( } TorchMlirLoweringContext::TorchMlirLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, Util::EmissionMap emit_status) + const std::string &name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) : LoweringContext( name, std::forward(device), - std::forward>(post_order), + std::forward>(post_order), std::forward(emit_status)), graph_(std::make_shared()), function_( @@ -66,9 +67,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext( } } -void TorchMlirLoweringContext::Lower(const Node* node) { - if (auto* torch_mlir_node = - dynamic_cast(node)) { +void TorchMlirLoweringContext::Lower(const Node *node) { + if (auto *torch_mlir_node = + dynamic_cast(node)) { TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this); CHECK(!ops.empty()) << "Failed to lower: " << *node; TORCH_CHECK_EQ(node->num_outputs(), ops.size()); @@ -82,19 +83,19 @@ void TorchMlirLoweringContext::Lower(const Node* node) { } void TorchMlirLoweringContext::SetUpAlias( - const std::vector& output_index, int64_t param_number, - const std::vector& param_index, bool must_alias) { + const std::vector &output_index, int64_t param_number, + const std::vector ¶m_index, bool must_alias) { input_output_aliases_.push_back( {output_index, param_number, param_index, must_alias}); } bool TorchMlirLoweringContext::CheckResultShape( - const BackendDataPtr& parameter_data, size_t result_idx) { - TORCH_CHECK( - result_idx < root_tuple_.size(), "Tried getting result shape at index ", - result_idx, " which is out of bounds!"); + const BackendDataPtr ¶meter_data, size_t result_idx) { + TORCH_CHECK(result_idx < root_tuple_.size(), + "Tried getting result shape at index ", result_idx, + " which is out of bounds!"); - torch::jit::Value* output = root_tuple_[result_idx]; + torch::jit::Value *output = root_tuple_[result_idx]; if (c10::TensorTypePtr tensor_type = output->type()->cast()) { @@ -111,7 +112,7 @@ bool TorchMlirLoweringContext::CheckResultShape( return false; } -size_t TorchMlirLoweringContext::AddResult(const Output& output) { +size_t TorchMlirLoweringContext::AddResult(const Output &output) { PRINT_FUNCTION(); return AddResult(GetOutputOp(output)); @@ -120,9 +121,10 @@ size_t TorchMlirLoweringContext::AddResult(const Output& output) { // Associates the given output with the input parameter of the given index and // shape. Only used for the operator-by-operator execution, mostly for // debugging purposes. -void TorchMlirLoweringContext::AddParameter( - const torch::lazy::Output& output, size_t index, - const torch::lazy::Shape& shape, const std::string& name) { +void TorchMlirLoweringContext::AddParameter(const torch::lazy::Output &output, + size_t index, + const torch::lazy::Shape &shape, + const std::string &name) { UNIMPLEMENTED_FUNCTION_ERROR(); } @@ -136,7 +138,7 @@ ComputationPtr TorchMlirLoweringContext::Build() { torch::jit::RefineTupleTypes(graph_); // Insert return values into graph. - for (torch::jit::Value* output : root_tuple_) { + for (torch::jit::Value *output : root_tuple_) { graph_->block()->registerOutput(output); } @@ -152,7 +154,6 @@ ComputationPtr TorchMlirLoweringContext::Build() { /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; }, /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true}); - // Convert MlirOperation to MlirModule. MlirLocation loc = mlirLocationUnknownGet(mlir_context_); MlirModule module_op = mlirModuleCreateEmpty(loc); @@ -162,14 +163,10 @@ ComputationPtr TorchMlirLoweringContext::Build() { // Apply passes to verify generated MLIR. auto pass_manager = mlirPassManagerCreate(mlir_context_); mlirPassManagerAddOwnedPass( - pass_manager, - mlirCreateVerifyBackendContractNoDecompositions() - ); + pass_manager, mlirCreateVerifyBackendContractNoDecompositions()); - MlirLogicalResult result = mlirPassManagerRunOnOp( - pass_manager, - mlirModuleGetOperation(module_op) - ); + MlirLogicalResult result = + mlirPassManagerRunOnOp(pass_manager, mlirModuleGetOperation(module_op)); if (mlirLogicalResultIsFailure(result)) { throw std::runtime_error("MLIR verification has failed."); @@ -178,12 +175,14 @@ ComputationPtr TorchMlirLoweringContext::Build() { return CreateComputation(module_op); } -ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirModule module_op) { - return std::make_shared( - module_op, mlir_context_, graph_, parameter_names_, input_output_aliases_); +ComputationPtr +TorchMlirLoweringContext::CreateComputation(MlirModule module_op) { + return std::make_shared(module_op, mlir_context_, + graph_, parameter_names_, + input_output_aliases_); } -torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { +torch::jit::Value *TorchMlirLoweringContext::GetOutputOp(const Output &output) { PRINT_FUNCTION(); auto it = emitted_outputs_.find(output); @@ -195,15 +194,14 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { // At this point the output better be present, otherwise there is an issue // with the lowering code. it = emitted_outputs_.find(output); - TORCH_CHECK( - it != emitted_outputs_.end(), - "No MLIR operation emitted for output: ", output.ToString()); + TORCH_CHECK(it != emitted_outputs_.end(), + "No MLIR operation emitted for output: ", output.ToString()); } return it->second; } -void TorchMlirLoweringContext::AssignOutputOp( - const Output& output, torch::jit::Value* op) { +void TorchMlirLoweringContext::AssignOutputOp(const Output &output, + torch::jit::Value *op) { PRINT_FUNCTION(); auto torch_mlir_node = @@ -211,48 +209,44 @@ void TorchMlirLoweringContext::AssignOutputOp( std::vector source_files, functions; std::vector line_numbers; - const auto& metadata = torch_mlir_node->metadata(); - const auto& frames = metadata.frame_info; + const auto &metadata = torch_mlir_node->metadata(); + const auto &frames = metadata.frame_info; if (!frames.empty()) { static std::vector g_roots = - string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":"); + string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":"); std::for_each(frames.rbegin(), frames.rend(), - [&](const torch::lazy::SourceLocation& location) { - functions.push_back(location.function); - line_numbers.push_back(location.line); + [&](const torch::lazy::SourceLocation &location) { + functions.push_back(location.function); + line_numbers.push_back(location.line); - std::string file_name = location.file; - for (const std::string& root : g_roots) { - if (startswith(file_name, root)) { - // location.file starts with root, strip it off - file_name = file_name.substr(root.size()); - break; - } - } - source_files.push_back(file_name); - }); + std::string file_name = location.file; + for (const std::string &root : g_roots) { + if (startswith(file_name, root)) { + // location.file starts with root, strip it off + file_name = file_name.substr(root.size()); + break; + } + } + source_files.push_back(file_name); + }); if (!source_files.empty()) { - op->node()->ss_( - c10::Symbol::attr("source_files"), source_files); - op->node()->ss_( - c10::Symbol::attr("functions"), functions); - op->node()->is_( - c10::Symbol::attr("line_numbers"), line_numbers); + op->node()->ss_(c10::Symbol::attr("source_files"), source_files); + op->node()->ss_(c10::Symbol::attr("functions"), functions); + op->node()->is_(c10::Symbol::attr("line_numbers"), line_numbers); } } auto scope = ::c10::Symbol::scope(metadata.scope); - op->node()->setScope( - c10::make_intrusive()->push(scope)); + op->node()->setScope(c10::make_intrusive()->push(scope)); emitted_outputs_[output] = std::move(op); } -torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { +torch::jit::Value *TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { PRINT_FUNCTION(); - if (!dynamic_cast(data.get())) { + if (!dynamic_cast(data.get())) { TORCH_CHECK( false, "Expected TorchMlirBackendData. Got some other BackendData type"); @@ -263,20 +257,21 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { - torch::jit::Value* param = + torch::jit::Value *param = graph_->addInput(c10::str("p", parameters_.size())); - auto* info = dynamic_cast(mlir_data->mlir_info()); + auto *info = + dynamic_cast(mlir_data->mlir_info()); TORCH_CHECK(info, "Expected TorchMlirBackendData::Info"); if (info->scalar.has_value()) { - auto& scalar = info->scalar.value(); + auto &scalar = info->scalar.value(); if (scalar.isFloatingPoint()) { param->setType(c10::FloatType::get()); } else if (scalar.isIntegral(true)) { param->setType(c10::IntType::get()); } else { - TORCH_CHECK( - false, "Unhandled scalar type: ", c10::toString(scalar.type())); + TORCH_CHECK(false, + "Unhandled scalar type: ", c10::toString(scalar.type())); } } else { // Save parameter shape information. @@ -305,7 +300,7 @@ std::shared_ptr TorchMlirLoweringContext::graph() const { return graph_; } -size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) { +size_t TorchMlirLoweringContext::AddResult(torch::jit::Value *op) { PRINT_FUNCTION(); root_tuple_.push_back(std::move(op)); return root_tuple_.size() - 1; @@ -313,9 +308,9 @@ size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) { // Sync vector of c10::Argument with type specified from parallel list of // jit::Value. There must be a 1:1 map between elements of args and values. -std::vector sync_argument_types( - const std::vector& args, - c10::ArrayRef values) { +std::vector +sync_argument_types(const std::vector &args, + c10::ArrayRef values) { TORCH_CHECK( args.size() == values.size(), "Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ", @@ -362,7 +357,7 @@ void TorchMlirLoweringContext::RegisterMlirDialects() { TorchMlirComputation::TorchMlirComputation( MlirModule module_op, MlirContext mlir_context, - const std::shared_ptr& graph, + const std::shared_ptr &graph, std::unordered_map parameters_map, InputOutputAliases input_output_aliases) : module_op_(std::move(module_op)), mlir_context_(std::move(mlir_context)), @@ -377,26 +372,25 @@ TorchMlirComputation::TorchMlirComputation( } } -int TorchMlirComputation::parameters_size() const { - return num_parameters_; -} +int TorchMlirComputation::parameters_size() const { return num_parameters_; } -const std::vector& +const std::vector & TorchMlirComputation::parameter_shapes() const { throw std::runtime_error( "todo(whc) implement ts computation shapes or change interface"); return parameter_shapes_; } -const std::vector& TorchMlirComputation::parameter_names() const { +const std::vector &TorchMlirComputation::parameter_names() const { return parameter_names_; } -const std::unordered_map& TorchMlirComputation::parameters_map() const { +const std::unordered_map & +TorchMlirComputation::parameters_map() const { return parameters_map_; } -const torch::lazy::Shape& TorchMlirComputation::result_shape() const { +const torch::lazy::Shape &TorchMlirComputation::result_shape() const { throw std::runtime_error( "todo(whc) implement ts computation shapes or change interface"); return result_shape_; @@ -411,13 +405,9 @@ MlirOperation TorchMlirComputation::func_op() const { return mlirBlockGetFirstOperation(block); } -MlirModule TorchMlirComputation::module_op() const { - return module_op_; -} +MlirModule TorchMlirComputation::module_op() const { return module_op_; } -MlirContext TorchMlirComputation::mlir_context() const { - return mlir_context_; -} +MlirContext TorchMlirComputation::mlir_context() const { return mlir_context_; } const std::string TorchMlirComputation::debug_string() const { std::stringstream ss; @@ -430,7 +420,7 @@ const std::string TorchMlirComputation::debug_string() const { // Parameter names ss << "Parameter names:\n"; - for (auto& p : parameter_names_) { + for (auto &p : parameter_names_) { ss << " " << p << "\n"; } ss << "\n"; @@ -451,10 +441,10 @@ const std::string TorchMlirComputation::debug_string() const { const std::string TorchMlirComputation::to_string() const { // Since we use the C-MLIR API, we need to use a callback to print. - MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) { + MlirStringCallback print_callback = [](MlirStringRef part, void *user_data) { // user_data is a void ptr to some data structure of our choice -- in this // case, the string stream where we'll be accumulating the strings. - std::stringstream* ss_ptr = static_cast(user_data); + std::stringstream *ss_ptr = static_cast(user_data); *ss_ptr << std::string(part.data, part.length); }; std::stringstream ss; @@ -462,7 +452,8 @@ const std::string TorchMlirComputation::to_string() const { // Setup flags for MLIR serialization. MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); mlirOpPrintingFlagsEnableDebugInfo(flags, FLAGS_torch_lazy_ir_debug, false); - mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, print_callback, &ss); + mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, + print_callback, &ss); return ss.str(); } diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h index f62a71ce7..3b226b468 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h @@ -39,35 +39,34 @@ public: }; using InputOutputAliases = std::vector; - TorchMlirLoweringContext( - const std::string& name, torch::lazy::BackendDevice device); - TorchMlirLoweringContext( - const std::string& name, torch::lazy::BackendDevice device, - c10::ArrayRef post_order, - torch::lazy::Util::EmissionMap emit_status); + TorchMlirLoweringContext(const std::string &name, + torch::lazy::BackendDevice device); + TorchMlirLoweringContext(const std::string &name, + torch::lazy::BackendDevice device, + c10::ArrayRef post_order, + torch::lazy::Util::EmissionMap emit_status); - void Lower(const Node* node); + void Lower(const Node *node); // Adds a new input/output alias. - void SetUpAlias( - const std::vector& output_index, int64_t param_number, - const std::vector& param_index, - bool must_alias = false) override; + void SetUpAlias(const std::vector &output_index, + int64_t param_number, const std::vector ¶m_index, + bool must_alias = false) override; // Check if parameter shape matches result at index. - bool CheckResultShape( - const BackendDataPtr& parameter_data, size_t result_idx) override; + bool CheckResultShape(const BackendDataPtr ¶meter_data, + size_t result_idx) override; // Adds the given output as a component of the result tuple and returns its // assigned position within the tuple. - size_t AddResult(const torch::lazy::Output& output) override; + size_t AddResult(const torch::lazy::Output &output) override; // Associates the given output with the input parameter of the given index and // shape. Only used for the operator-by-operator execution, mostly for // debugging purposes. - void AddParameter( - const torch::lazy::Output& output, size_t index, - const torch::lazy::Shape& shape, const std::string& name) override; + void AddParameter(const torch::lazy::Output &output, size_t index, + const torch::lazy::Shape &shape, + const std::string &name) override; // Build the computation capturing all the operations created with the // embedded builder (returned by the builder() API). @@ -78,27 +77,27 @@ public: // Retrieves the lowered operation for an output. If the requested output is // not available yet, the graph behind the output's Node is lowered, and the // corresponding TS operation returned. - torch::jit::Value* GetOutputOp(const Output& output); + torch::jit::Value *GetOutputOp(const Output &output); // Assigns the given TS operation to the specified output. As outputs are // lowered in a post-order fashion, later nodes should always find their // operands among the emitted outputs. - void AssignOutputOp(const Output& output, torch::jit::Value* op); + void AssignOutputOp(const Output &output, torch::jit::Value *op); // If a parameter associated with data has already been declared, it will be // returned. Otherwise a new one will be created, associated with the tensor // held in data. - torch::jit::Value* GetParameter(BackendDataPtr data); + torch::jit::Value *GetParameter(BackendDataPtr data); std::shared_ptr graph() const; protected: struct Parameter { - torch::jit::Value* param; + torch::jit::Value *param; size_t index = 0; }; - size_t AddResult(torch::jit::Value* op); + size_t AddResult(torch::jit::Value *op); // Creates a jit::Function from the current jit::Graph. Input and output // type information is patched to include shape. @@ -113,8 +112,8 @@ protected: MlirContext mlir_context_; std::unordered_map parameters_map_; std::unordered_map parameter_names_; - std::vector root_tuple_; - OutputMap emitted_outputs_; + std::vector root_tuple_; + OutputMap emitted_outputs_; }; class TORCH_API TorchMlirComputation : public torch::lazy::Computation { @@ -122,21 +121,20 @@ public: using InputOutputAliases = TorchMlirLoweringContext::InputOutputAliases; using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias; - TorchMlirComputation( - MlirModule module_op, MlirContext mlir_context, - const std::shared_ptr& graph, - std::unordered_map parameters_map, - InputOutputAliases input_output_aliases); + TorchMlirComputation(MlirModule module_op, MlirContext mlir_context, + const std::shared_ptr &graph, + std::unordered_map parameters_map, + InputOutputAliases input_output_aliases); int parameters_size() const override; - const std::vector& parameter_shapes() const override; + const std::vector ¶meter_shapes() const override; - const std::vector& parameter_names() const override; + const std::vector ¶meter_names() const override; - const std::unordered_map& parameters_map() const; + const std::unordered_map ¶meters_map() const; - const torch::lazy::Shape& result_shape() const override; + const torch::lazy::Shape &result_shape() const override; std::shared_ptr graph() const; diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp index 7d9fe056d..a0e4bae76 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -10,8 +10,8 @@ // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp //===----------------------------------------------------------------------===// -#include #include +#include #include #include #include @@ -33,16 +33,16 @@ #include "generated/LazyIr.h" #include "generated/LazyNativeFunctions.h" #include "generated/shape_inference.h" -#include "ops/to_copy.h" -#include "ops/unbind_int.h" -#include "ops/split.h" #include "ops/index.h" #include "ops/ivalue.h" +#include "ops/split.h" +#include "ops/to_copy.h" +#include "ops/unbind_int.h" #include "utils/exception.h" #include "utils/sys_utils.h" namespace { -at::Tensor to_meta(const at::Tensor& tensor) { +at::Tensor to_meta(const at::Tensor &tensor) { // undefined tensors can't be converted to the meta device, since they don't // have sizes/strides if (!tensor.defined()) @@ -60,7 +60,7 @@ at::Tensor to_meta(const at::Tensor& tensor) { return out; } -c10::optional to_meta(const c10::optional& tensor) { +c10::optional to_meta(const c10::optional &tensor) { if (tensor.has_value()) { return to_meta(*tensor); } @@ -70,16 +70,17 @@ c10::optional to_meta(const c10::optional& tensor) { std::vector to_meta(at::ITensorListRef t_list) { std::vector outs; outs.reserve(t_list.size()); - for (const auto& tensor : t_list) { + for (const auto &tensor : t_list) { outs.push_back(to_meta(tensor)); } return outs; } -c10::List> to_meta(const c10::List>& t_list) { +c10::List> +to_meta(const c10::List> &t_list) { c10::List> outs; outs.reserve(t_list.size()); - for (const auto& tensor : t_list) { + for (const auto &tensor : t_list) { outs.push_back(to_meta(tensor)); } return outs; @@ -91,9 +92,9 @@ namespace lazy { namespace { -at::Tensor CreateLtcTensor( - const at::Tensor& tensor, - const c10::optional& device) { +at::Tensor +CreateLtcTensor(const at::Tensor &tensor, + const c10::optional &device) { if (tensor.defined() && device) { return torch::lazy::CreateAtenFromLtcTensor( torch::lazy::LazyTensor::Create(tensor, *device)); @@ -102,7 +103,7 @@ at::Tensor CreateLtcTensor( } c10::optional -GetLtcDevice(const c10::optional& device) { +GetLtcDevice(const c10::optional &device) { if (!device) { return c10::nullopt; } @@ -112,24 +113,23 @@ GetLtcDevice(const c10::optional& device) { return torch::lazy::atenDeviceToBackendDevice(*device); } -torch::lazy::Value MaybeExpand( - const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) { +torch::lazy::Value MaybeExpand(const torch::lazy::Value &input, + const torch::lazy::Shape &target_shape) { if (input.shape().sizes() == target_shape.sizes()) { return input; } - return torch::lazy::MakeExpand( - input, target_shape.sizes().vec(), - /*is_scalar_expand=*/false); + return torch::lazy::MakeExpand(input, target_shape.sizes().vec(), + /*is_scalar_expand=*/false); } -void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { +void copy_(torch::lazy::LazyTensorPtr &input, torch::lazy::LazyTensorPtr &src) { if (input->GetDevice() == src->GetDevice()) { torch::lazy::Value copy_value; if (input->dtype() == src->dtype()) { copy_value = src->GetIrValue(); } else { - copy_value = torch::lazy::MakeCast( - src->GetIrValue(), input->dtype(), src->dtype()); + copy_value = torch::lazy::MakeCast(src->GetIrValue(), input->dtype(), + src->dtype()); } input->SetIrValue(MaybeExpand(copy_value, input->shape())); } else { @@ -146,15 +146,17 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { // clone is special in LT because we make it a no-op. // This should be safe to do, because every operator in the LT is functional. -at::Tensor LazyNativeFunctions::clone( - const at::Tensor& self, c10::optional memory_format) { +at::Tensor +LazyNativeFunctions::clone(const at::Tensor &self, + c10::optional memory_format) { auto self_lt = torch::lazy::TryGetLtcTensor(self); return torch::lazy::CreateAtenFromLtcTensor( self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice())); } -at::Tensor LazyNativeFunctions::_copy_from( - const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { +at::Tensor LazyNativeFunctions::_copy_from(const at::Tensor &self, + const at::Tensor &dst, + bool non_blocking) { TORCH_LAZY_FN_COUNTER("lazy::"); auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); auto self_tensor = torch::lazy::TryGetLtcTensor(self); @@ -199,16 +201,16 @@ at::Tensor LazyNativeFunctions::_copy_from( } } else { copy_(dst_tensor, self_tensor); - auto* impl = - dynamic_cast(dst.unsafeGetTensorImpl()); + auto *impl = + dynamic_cast(dst.unsafeGetTensorImpl()); impl->set_tensor(dst_tensor); } } return dst; } -at::Tensor LazyNativeFunctions::_copy_from_and_resize( - const at::Tensor& self, const at::Tensor& dst) { +at::Tensor LazyNativeFunctions::_copy_from_and_resize(const at::Tensor &self, + const at::Tensor &dst) { TORCH_LAZY_FN_COUNTER("lazy::"); auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); auto self_tensor = torch::lazy::TryGetLtcTensor(self); @@ -223,8 +225,8 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize( dst.resize_as_(typed_tensor).copy_(typed_tensor); } else { // at this point we know dst is a lazy tensor - auto* dest_impl = - dynamic_cast(dst.unsafeGetTensorImpl()); + auto *dest_impl = + dynamic_cast(dst.unsafeGetTensorImpl()); dest_impl->tensor()->UpdateFromTensorOut(self_tensor); dest_impl->force_refresh_sizes(); } @@ -232,15 +234,16 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize( } at::Tensor LazyNativeFunctions::_to_copy( - const at::Tensor& self, c10::optional dtype, + const at::Tensor &self, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, bool non_blocking, c10::optional memory_format) { PRINT_FUNCTION(); auto options = self.options(); if (dtype) { - // I put each of these setters in a conditional instead of doing `self.options().dtype(dtype).layout(layout)... - // because calling .dtype(nullopt) on an options() that already has dtype appears to wipe it + // I put each of these setters in a conditional instead of doing + // `self.options().dtype(dtype).layout(layout)... because calling + // .dtype(nullopt) on an options() that already has dtype appears to wipe it options = options.dtype(dtype); } if (layout) { @@ -261,8 +264,9 @@ at::Tensor LazyNativeFunctions::_to_copy( if (!lazy_self && device && device->type() == c10::kLazy) { // Case 1: eager->lazy (we create a new lazy tensor) // See Note [Lazy Tensor Functionalization] - // Invariant: if the functionalization key is in the exclude set, then we're expected - // to return an ordinary tensor, which will be "lifted" into a functional wrapper later. + // Invariant: if the functionalization key is in the exclude set, then we're + // expected to return an ordinary tensor, which will be "lifted" into a + // functional wrapper later. bool functionalize_output = !c10::impl::tls_local_dispatch_key_set().excluded_.has( c10::DispatchKey::Functionalize); @@ -270,7 +274,8 @@ at::Tensor LazyNativeFunctions::_to_copy( self, options, *device, /*non_blocking=*/non_blocking, /*functionalize_output=*/functionalize_output); } else if (device && device->type() != c10::kLazy) { - // Case 2: lazy->eager (forces a graph break since we are materializing a tensor) + // Case 2: lazy->eager (forces a graph break since we are materializing a + // tensor) TORCH_INTERNAL_ASSERT(lazy_self); auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); @@ -278,22 +283,24 @@ at::Tensor LazyNativeFunctions::_to_copy( auto moved_eager_tensor = eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); return moved_eager_tensor; - } else if ( - device && device->type() == c10::kLazy && device->has_index() && - device->index() != self.device().index()) { + } else if (device && device->type() == c10::kLazy && device->has_index() && + device->index() != self.device().index()) { // Case 3: lazy:0 -> lazy:1 // TODO(whc) what do we actually want to do here? // option 1: materialize, move eager tensor, create new lazy tensor - // - this should be our default, as it is what would happen before we implemented _to_copy + // - this should be our default, as it is what would happen before we + // implemented _to_copy // - actually combines case 1 + case 2 // option 2: support multiple devices inside one lazy/TS executor (case 4) - // - but: we may have other assumptions that there is just one device per executor? so don't take this lightly + // - but: we may have other assumptions that there is just one device + // per executor? so don't take this lightly TORCH_INTERNAL_ASSERT(lazy_self); auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); // we move the eager tensor to the 'eager' equivalent of our lazy device - // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is what we use + // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is + // what we use auto eager_device = c10::Device( torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index()); options = options.device(eager_device); @@ -305,12 +312,14 @@ at::Tensor LazyNativeFunctions::_to_copy( return torch::lazy::CreateAtenFromLtcTensor(lazy_self); } else { - // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy graph) + // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy + // graph) - // Note: captured _to_copy will be executed with real eager tensors, not lazy tensors. - // We DO NOT want to burn 'lazy:0' as the device into this captured IR, or we will try to - // convert an eager tensor back to a lazy one inside the torchscript executor - // lazy:0 -> lazy:1 is handled in case3, so we can safely drop the device argument + // Note: captured _to_copy will be executed with real eager tensors, not + // lazy tensors. We DO NOT want to burn 'lazy:0' as the device into this + // captured IR, or we will try to convert an eager tensor back to a lazy one + // inside the torchscript executor lazy:0 -> lazy:1 is handled in case3, so + // we can safely drop the device argument device = c10::nullopt; auto shapes = torch::lazy::compute_shape__to_copy( @@ -327,257 +336,297 @@ at::Tensor LazyNativeFunctions::_to_copy( } }; -at::Tensor LazyNativeFunctions::_unsafe_view( - const at::Tensor& self, at::IntArrayRef size) { +at::Tensor LazyNativeFunctions::_unsafe_view(const at::Tensor &self, + at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); - return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size)); + return LazyNativeFunctions::view_copy_symint(self, + c10::fromIntArrayRefSlow(size)); } -at::Tensor LazyNativeFunctions::t(const at::Tensor& self) { +at::Tensor LazyNativeFunctions::t(const at::Tensor &self) { TORCH_LAZY_FN_COUNTER("lazy::"); return at::functionalization::functionalize_aten_op::call(self); } -std::vector LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) { +std::vector LazyNativeFunctions::unbind_copy(const at::Tensor &self, + int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), dim); + + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = + torch::lazy::ReuseNode(lazy_self->GetIrValue(), dim); if (!node) { auto self_meta = to_meta(self); - auto out_meta = at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim); - + auto out_meta = + at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim); + std::vector shapes; - for (const auto & shape : out_meta) { + for (const auto &shape : out_meta) { shapes.push_back( - torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) - ); + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())); } - if(torch::lazy::symbolicShapeEnabled()){ - std::vector inputs = { self, dim }; - const char* schema_str = "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]"; + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, dim}; + const char *schema_str = + "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]"; applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), dim, std::move(shapes)); - CacheNode(node); - } - - std::vector result; - for (size_t i = 0; i < node->num_outputs(); ++i) { - result.push_back( - torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) - ) - ); - } - - return result; -} - -std::vector LazyNativeFunctions::split_with_sizes_copy_symint(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { - TORCH_LAZY_FN_COUNTER("lazy::"); - auto common_device = torch::lazy::GetBackendDevice(self); - TORCH_INTERNAL_ASSERT(common_device); - - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim); - if (!node) { - auto self_meta = to_meta(self); - auto out_meta = at::compositeexplicitautogradnonfunctional::split_with_sizes_copy_symint(self_meta, split_sizes, dim); - - std::vector shapes; - for (const auto & shape : out_meta) { - shapes.push_back( - torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) - ); - } - - if(torch::lazy::symbolicShapeEnabled()){ - std::vector inputs = { self, split_sizes, dim }; - const char* schema_str = "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]"; - applySymbolicShapesOnLT(schema_str, inputs, shapes); - } - - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, std::move(shapes)); + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), dim, + std::move(shapes)); CacheNode(node); } std::vector result; for (size_t i = 0; i < node->num_outputs(); ++i) { result.push_back( - torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) - ) - ); + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, i), *common_device))); } return result; } -std::vector LazyNativeFunctions::split_copy_symint(const at::Tensor & self, c10::SymInt split_size, int64_t dim) { +std::vector LazyNativeFunctions::split_with_sizes_copy_symint( + const at::Tensor &self, c10::SymIntArrayRef split_sizes, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim); + + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = torch::lazy::ReuseNode( + lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim); if (!node) { auto self_meta = to_meta(self); - auto out_meta = at::compositeexplicitautogradnonfunctional::split_copy_symint(self_meta, split_size, dim); + auto out_meta = at::compositeexplicitautogradnonfunctional:: + split_with_sizes_copy_symint(self_meta, split_sizes, dim); std::vector shapes; - for (const auto & shape : out_meta) { + for (const auto &shape : out_meta) { shapes.push_back( - torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) - ); + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())); + } + + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, split_sizes, dim}; + const char *schema_str = "aten::split_with_sizes_copy(Tensor self, " + "SymInt[] split_sizes, int dim=0) -> Tensor[]"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); + } + + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, + std::move(shapes)); + CacheNode(node); + } + + std::vector result; + for (size_t i = 0; i < node->num_outputs(); ++i) { + result.push_back( + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, i), *common_device))); + } + + return result; +} + +std::vector +LazyNativeFunctions::split_copy_symint(const at::Tensor &self, + c10::SymInt split_size, int64_t dim) { + TORCH_LAZY_FN_COUNTER("lazy::"); + auto common_device = torch::lazy::GetBackendDevice(self); + TORCH_INTERNAL_ASSERT(common_device); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = torch::lazy::ReuseNode( + lazy_self->GetIrValue(), GetSymIntValue(split_size), dim); + if (!node) { + auto self_meta = to_meta(self); + auto out_meta = + at::compositeexplicitautogradnonfunctional::split_copy_symint( + self_meta, split_size, dim); + + std::vector shapes; + for (const auto &shape : out_meta) { + shapes.push_back( + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())); } const size_t num_outputs = shapes.size(); - if(torch::lazy::symbolicShapeEnabled()){ - std::vector inputs = { self, split_size, dim }; - const char* schema_str = "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"; - applySymbolicShapesOnLT(schema_str, inputs, shapes); + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, split_size, dim}; + const char *schema_str = "aten::split_copy.Tensor(Tensor self, SymInt " + "split_size, int dim=0) -> Tensor[]"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, std::move(shapes), num_outputs); + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, + std::move(shapes), num_outputs); CacheNode(node); } std::vector result; for (size_t i = 0; i < node->num_outputs(); ++i) { result.push_back( - torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) - ) - ); + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, i), *common_device))); } return result; } -at::Tensor LazyNativeFunctions::index(const at::Tensor & self, const c10::List> & indices) { +at::Tensor LazyNativeFunctions::index( + const at::Tensor &self, + const c10::List> &indices) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); std::vector values; - for (const auto & it : indices) { + for (const auto &it : indices) { c10::optional tensor = it; - LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); - values.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode(c10::IValue()), 0)); + LazyTensorPtr lazy_tensor = + torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); + values.push_back( + lazy_tensor + ? lazy_tensor->GetIrValue() + : torch::lazy::Value(MakeNode(c10::IValue()), 0)); } auto list = MakeNode(values); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), list); + torch::lazy::NodePtr node = + torch::lazy::ReuseNode(lazy_self->GetIrValue(), list); if (!node) { auto self_meta = to_meta(self); auto indices_meta = to_meta(indices); auto out_meta = at::meta::index(self_meta, indices_meta); - std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; + std::vector shapes{ + torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; TORCH_INTERNAL_ASSERT(shapes.size() == 1); - if(torch::lazy::symbolicShapeEnabled()) { - std::vector inputs = { self, indices }; - const char* schema_str = "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"; + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, indices}; + const char *schema_str = + "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"; applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), list, std::move(shapes)); + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), list, + std::move(shapes)); CacheNode(node); } auto result = torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(std::move(node), *common_device)); + torch::lazy::LazyTensor::Create(std::move(node), *common_device)); return result; } -at::Tensor LazyNativeFunctions::index_put(const at::Tensor & self, const c10::List> & indices, const at::Tensor & values, bool accumulate) { +at::Tensor LazyNativeFunctions::index_put( + const at::Tensor &self, const c10::List> &indices, + const at::Tensor &values, bool accumulate) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - LazyTensorPtr lazy_valeus = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + LazyTensorPtr lazy_valeus = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device); std::vector indices_vector; - for (const auto & it : indices) { + for (const auto &it : indices) { c10::optional tensor = it; - LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); - indices_vector.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode(c10::IValue()), 0)); + LazyTensorPtr lazy_tensor = + torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); + indices_vector.push_back( + lazy_tensor + ? lazy_tensor->GetIrValue() + : torch::lazy::Value(MakeNode(c10::IValue()), 0)); } auto indices_list = MakeNode(indices_vector); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate); + torch::lazy::NodePtr node = + torch::lazy::ReuseNode(lazy_self->GetIrValue(), indices_list, + lazy_valeus->GetIrValue(), accumulate); if (!node) { auto self_meta = to_meta(self); auto indices_meta = to_meta(indices); auto values_meta = to_meta(values); - auto out_meta = at::compositeexplicitautograd::index_put(self_meta, indices_meta, values_meta, accumulate); + auto out_meta = at::compositeexplicitautograd::index_put( + self_meta, indices_meta, values_meta, accumulate); - std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; + std::vector shapes{ + torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; TORCH_INTERNAL_ASSERT(shapes.size() == 1); - if(torch::lazy::symbolicShapeEnabled()) { - std::vector inputs = { self, indices, values }; - const char* schema_str = "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"; + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, indices, values}; + const char *schema_str = + "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool " + "accumulate=False) -> Tensor"; applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate, std::move(shapes)); + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), + accumulate, std::move(shapes)); CacheNode(node); } auto result = torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(std::move(node), *common_device)); + torch::lazy::LazyTensor::Create(std::move(node), *common_device)); return result; } // This is needed by the torch.tensor constructor. // LazyTensor always opts into functionalization. -// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object. -at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) { +// "lifting" a tensor for functionalization means wrapping it in a +// FunctionalTensorWrapper object. +at::Tensor LazyNativeFunctions::lift(const at::Tensor &tensor) { TORCH_INTERNAL_ASSERT( !at::functionalization::impl::isFunctionalTensor(tensor)); return at::functionalization::impl::to_functional_tensor(tensor); } -at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) { +at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor &tensor) { TORCH_INTERNAL_ASSERT( !at::functionalization::impl::isFunctionalTensor(tensor)); return at::functionalization::impl::to_functional_tensor(tensor); } -// All of the below ops correspond to CompositeExplicitAutograd kernels from core -// that call into view operators internally. -// These are all composite ops that LTC can technically re-use / get for free, -// but we need to "functionalize" them to remove the view ops before we can use them. +// All of the below ops correspond to CompositeExplicitAutograd kernels from +// core that call into view operators internally. These are all composite ops +// that LTC can technically re-use / get for free, but we need to +// "functionalize" them to remove the view ops before we can use them. at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) { return at::functionalization::functionalize_aten_op::call(tensors); } at::Tensor LazyNativeFunctions::new_empty_strided_symint( - const at::Tensor& self, - c10::SymIntArrayRef size, - c10::SymIntArrayRef stride, - c10::optional dtype, - c10::optional layout, - c10::optional device, + const at::Tensor &self, c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, c10::optional dtype, + c10::optional layout, c10::optional device, c10::optional pin_memory) { if (!device || device->type() == c10::DeviceType::Lazy) { - return at::functionalization::functionalize_aten_op_symint< - ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, layout, - device, pin_memory); + return at::functionalization::functionalize_aten_op_symint::call(self, size, stride, dtype, layout, device, + pin_memory); } - // For cases when device != lazy, for example: lazy_tensor.new_empty_strided(..., "cpu") - // we need to avoid explicit functionalization. To do that we create regular cpu tensors. + // For cases when device != lazy, for example: + // lazy_tensor.new_empty_strided(..., "cpu") we need to avoid explicit + // functionalization. To do that we create regular cpu tensors. at::Tensor t = at::empty_symint( size, (dtype ? dtype : c10::optional(self.scalar_type())), (layout ? layout : c10::optional(self.layout())), device, @@ -585,65 +634,63 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint( return t.as_strided_symint(size, stride, /*storage_offset=*/0); } -at::Tensor LazyNativeFunctions::narrow_copy_symint( - const at::Tensor& self, - int64_t dim, - c10::SymInt start, - c10::SymInt length) { +at::Tensor LazyNativeFunctions::narrow_copy_symint(const at::Tensor &self, + int64_t dim, + c10::SymInt start, + c10::SymInt length) { return at::functionalization::functionalize_aten_op_symint::call(self, dim, start, length); } -at::Tensor LazyNativeFunctions::pixel_shuffle( - const at::Tensor& self, int64_t upscale_factor) { +at::Tensor LazyNativeFunctions::pixel_shuffle(const at::Tensor &self, + int64_t upscale_factor) { return at::functionalization::functionalize_aten_op::call(self, upscale_factor); } -at::Tensor LazyNativeFunctions::pixel_unshuffle( - const at::Tensor& self, int64_t downscale_factor) { +at::Tensor LazyNativeFunctions::pixel_unshuffle(const at::Tensor &self, + int64_t downscale_factor) { return at::functionalization::functionalize_aten_op::call(self, downscale_factor); } -at::Tensor LazyNativeFunctions::select_backward( - const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim, - int64_t index) { +at::Tensor LazyNativeFunctions::select_backward(const at::Tensor &grad_output, + at::IntArrayRef input_sizes, + int64_t dim, int64_t index) { return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, index); } at::Tensor LazyNativeFunctions::slice_backward_symint( - const at::Tensor& grad_output, - at::SymIntArrayRef input_sizes, - int64_t dim, - c10::SymInt start, - c10::SymInt end, - c10::SymInt step) { + const at::Tensor &grad_output, at::SymIntArrayRef input_sizes, int64_t dim, + c10::SymInt start, c10::SymInt end, c10::SymInt step) { return at::functionalization::functionalize_aten_op_symint::call(grad_output, input_sizes, dim, start, end, step); } -at::Tensor LazyNativeFunctions::diagonal_backward( - const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t offset, - int64_t dim1, int64_t dim2) { +at::Tensor LazyNativeFunctions::diagonal_backward(const at::Tensor &grad_output, + at::IntArrayRef input_sizes, + int64_t offset, int64_t dim1, + int64_t dim2) { return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, offset, dim1, dim2); } at::Tensor LazyNativeFunctions::_trilinear( - const at::Tensor& i1, const at::Tensor& i2, const at::Tensor& i3, + const at::Tensor &i1, const at::Tensor &i2, const at::Tensor &i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim) { - return at::functionalization::functionalize_aten_op:: - call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim); + return at::functionalization::functionalize_aten_op::call(i1, i2, i3, expand1, expand2, expand3, sumdim, + unroll_dim); } at::Tensor LazyNativeFunctions::linalg_pinv( - const at::Tensor& self, const c10::optional& atol, - const c10::optional& rtol, bool hermitian) { + const at::Tensor &self, const c10::optional &atol, + const c10::optional &rtol, bool hermitian) { return at::functionalization::functionalize_aten_op::call(self, atol, rtol, hermitian); } // functionalize_aten_op can't handle out= ops directly. -// Instead, we can call the composite kernel from core, and copy and mutations back to the inputs. -at::Tensor& LazyNativeFunctions::logsumexp_out( - const at::Tensor& self, at::IntArrayRef dim, bool keepdim, - at::Tensor& out) { +// Instead, we can call the composite kernel from core, and copy and mutations +// back to the inputs. +at::Tensor &LazyNativeFunctions::logsumexp_out(const at::Tensor &self, + at::IntArrayRef dim, + bool keepdim, at::Tensor &out) { auto self_wrapped = at::functionalization::impl::to_functional_tensor(self); auto out_wrapped = at::functionalization::impl::to_functional_tensor(out); // directly call the composite kernel from core. diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp index 39dc1ad0c..0f31fab2c 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp @@ -18,11 +18,10 @@ namespace lazy { namespace { -hash_t OperandHashes( - const OpList& operands, const c10::ArrayRef& shapes, - const hash_t& seed, bool bakeInSizes) { +hash_t OperandHashes(const OpList &operands, const c10::ArrayRef &shapes, + const hash_t &seed, bool bakeInSizes) { hash_t hash = seed; - for (auto& operand : operands) { + for (auto &operand : operands) { if (!operand) { hash = HashCombine(hash, static_cast(kNullOpt)); continue; @@ -30,7 +29,7 @@ hash_t OperandHashes( auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash(); hash = HashCombine(hash, operand_hash); } - for (auto& shape : shapes) { + for (auto &shape : shapes) { hash = HashCombine(hash, shape.hash(bakeInSizes)); } return hash; @@ -38,53 +37,51 @@ hash_t OperandHashes( } // namespace - -// Adds a static hook that is run after every single TorchMlirNode is initialized -static std::vector> constructor_hooks; -void TorchMlirNode::addConstructorHook(std::function f) { +// Adds a static hook that is run after every single TorchMlirNode is +// initialized +static std::vector> constructor_hooks; +void TorchMlirNode::addConstructorHook(std::function f) { constructor_hooks.emplace_back(f); } -TorchMlirNode::TorchMlirNode( - OpKind op, OpList operands, std::vector&& shapes, size_t num_outputs, - hash_t hash_seed) +TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, + std::vector &&shapes, size_t num_outputs, + hash_t hash_seed) : Node(op, operands, std::move(shapes), num_outputs) { hash_seed = HashCombine(op.hash(), hash_seed); shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true); - dag_hash_ = - (enableDynamicShape() - ? OperandHashes(operands, this->shapes(), hash_seed, false) - : shape_hash_); + dag_hash_ = (enableDynamicShape() + ? OperandHashes(operands, this->shapes(), hash_seed, false) + : shape_hash_); - for (std::function& f : constructor_hooks) { + for (std::function &f : constructor_hooks) { f(this); } } -TorchMlirNode::TorchMlirNode( - OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed) - : TorchMlirNode( - op, operands, std::vector{}, num_outputs, hash_seed) { +TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, + const std::function &shape_fn, + size_t num_outputs, hash_t hash_seed) + : TorchMlirNode(op, operands, std::vector{}, num_outputs, + hash_seed) { addComputedShape(shape_fn); } -TorchMlirNode::TorchMlirNode( - OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed) - : TorchMlirNode( - op, operands, std::vector{}, num_outputs, hash_seed) {} +TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, size_t num_outputs, + hash_t hash_seed) + : TorchMlirNode(op, operands, std::vector{}, num_outputs, + hash_seed) {} -TorchMlirNode::TorchMlirNode( - OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed) +TorchMlirNode::TorchMlirNode(OpKind op, Shape shape, size_t num_outputs, + hash_t hash_seed) : TorchMlirNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {} hash_t TorchMlirNode::hash() const { return dag_hash_; } hash_t TorchMlirNode::shapeHash() const { return shape_hash_; } - -TorchMlirNode* TorchMlirNode::mlir_node(int index) const { - return dynamic_cast(operands_.at(index).get()); +TorchMlirNode *TorchMlirNode::mlir_node(int index) const { + return dynamic_cast(operands_.at(index).get()); } /////////////////////////////////////////////////////////////////////////////// @@ -107,11 +104,12 @@ TorchMlirTensorList::TorchMlirTensorList(OpList values) /*num_outputs=*/1, /*hash_seed=*/kHashSeed) {} -torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { - std::vector tensor_list; +torch::lazy::TorchMlirOpVector +TorchMlirTensorList::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { + std::vector tensor_list; CHECK(!operands().empty()); - for (const torch::lazy::Output& operand : operands()) { + for (const torch::lazy::Output &operand : operands()) { tensor_list.emplace_back(loctx->GetOutputOp(operand)); } auto graph = function->graph(); @@ -140,16 +138,17 @@ TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values) /*num_outputs=*/1, /*hash_seed=*/kHashSeed) {} -torch::lazy::TorchMlirOpVector TorchMlirOptionalTensorList::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { - std::vector tensor_list; +torch::lazy::TorchMlirOpVector +TorchMlirOptionalTensorList::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { + std::vector tensor_list; CHECK(!operands().empty()); - for (const torch::lazy::Output& operand : operands()) { + for (const torch::lazy::Output &operand : operands()) { tensor_list.emplace_back(loctx->GetOutputOp(operand)); } auto graph = function->graph(); - auto listnode = - graph->insertNode(graph->createList(c10::OptionalType::create(c10::TensorType::get()), tensor_list)); + auto listnode = graph->insertNode(graph->createList( + c10::OptionalType::create(c10::TensorType::get()), tensor_list)); return {listnode->output()}; } diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node.h b/projects/ltc/csrc/base_lazy_backend/mlir_node.h index a76ec0b05..e5738a921 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node.h @@ -27,23 +27,22 @@ namespace lazy { class TORCH_API TorchMlirNode : public torch::lazy::Node { public: - TorchMlirNode( - OpKind op, OpList operands, std::vector&& shapes, - size_t num_outputs, hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, OpList operands, std::vector &&shapes, + size_t num_outputs, hash_t hash_seed = kHashSeed); - TorchMlirNode( - OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, OpList operands, + const std::function &shape_fn, size_t num_outputs, + hash_t hash_seed = kHashSeed); - TorchMlirNode( - OpKind op, OpList operands, size_t num_outputs, - hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, OpList operands, size_t num_outputs, + hash_t hash_seed = kHashSeed); - TorchMlirNode( - OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, Shape shape, size_t num_outputs, + hash_t hash_seed = kHashSeed); - // Adds a static hook that is run after every single TorchMlirNode is constructed - static void addConstructorHook(std::function); + // Adds a static hook that is run after every single TorchMlirNode is + // constructed + static void addConstructorHook(std::function); ~TorchMlirNode() override = default; @@ -51,10 +50,10 @@ public: hash_t shapeHash() const override; - TorchMlirNode* mlir_node(int index) const; + TorchMlirNode *mlir_node(int index) const; - virtual TorchMlirOpVector - Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const; + virtual TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const; private: // The hash of the dag WITH size info. Used for shape caching @@ -86,22 +85,23 @@ struct TORCH_API TorchMlirTensorList : public TorchMlirNode { TorchMlirTensorList() = delete; TorchMlirTensorList(OpList values); - torch::lazy::TorchMlirOpVector Lower( - TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + torch::lazy::TorchMlirOpVector + Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const override; }; -// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also represent -// optional tensors, so the output type for this op is !torch.list>. +// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also +// represent optional tensors, so the output type for this op is +// !torch.list>. struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode { static OpKind ClassOpKind(); TorchMlirOptionalTensorList() = delete; TorchMlirOptionalTensorList(OpList values); - torch::lazy::TorchMlirOpVector Lower( - TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + torch::lazy::TorchMlirOpVector + Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const override; }; } // namespace lazy diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp index a21bb93f0..b52b724f0 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp @@ -31,21 +31,23 @@ namespace torch { namespace lazy { -TorchMlirOpVector LowerTorchMlirBuiltin( - TorchMlirFunction function, c10::Symbol sym, - const std::vector tensor_types, - const std::vector& arguments, - const std::vector& kwarguments) { +TorchMlirOpVector +LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym, + const std::vector tensor_types, + const std::vector &arguments, + const std::vector &kwarguments) { // Workaround for ListType::isSubtypeOfExt behavior which leads to // the problems with JIT schema matching, so we need to keep // c10::ListType empty before magic_method->call function call. auto dummy_graph = torch::jit::Graph(); for (auto arg : arguments) { - torch::jit::Value* value = arg.value(dummy_graph); + torch::jit::Value *value = arg.value(dummy_graph); if (value->type()->kind() == c10::TypeKind::ListType) { - auto list_element_type = value->type()->cast()->getElementType(); + auto list_element_type = + value->type()->cast()->getElementType(); if (list_element_type->cast()) { - value->setType(c10::ListType::create(c10::OptionalType::create(c10::TensorType::get()))); + value->setType(c10::ListType::create( + c10::OptionalType::create(c10::TensorType::get()))); } else { value->setType(c10::ListType::create(c10::TensorType::get())); } @@ -56,25 +58,27 @@ TorchMlirOpVector LowerTorchMlirBuiltin( std::make_shared(sym, at::nullopt); auto magic_method = std::make_shared("", builtin); auto ret = magic_method->call({}, *function, arguments, kwarguments, 0); - auto sv = dynamic_cast(ret.get()); + auto sv = dynamic_cast(ret.get()); CHECK(sv); TorchMlirOpVector results; if (sv->getValue()->type()->kind() == c10::TypeKind::ListType) { - // Unpack dynamic multi-output operations like aten::split with Tensor[] output type. - // This is required to have consistent input types for multi-output node consumers. - torch::jit::Node * node = function->graph()->createListUnpack(sv->getValue(), tensor_types.size()); + // Unpack dynamic multi-output operations like aten::split with Tensor[] + // output type. This is required to have consistent input types for + // multi-output node consumers. + torch::jit::Node *node = function->graph()->createListUnpack( + sv->getValue(), tensor_types.size()); function->graph()->insertNode(node); - for (const auto & output : node->outputs()) { + for (const auto &output : node->outputs()) { results.push_back(output); } } else if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) { - // Op returns multiple values and the number of outputs is static and defined - // by the operation schema. + // Op returns multiple values and the number of outputs is static and + // defined by the operation schema. const auto tuple_call_result = sv->asTuple({}, *function); - for (const auto& tuple_component : tuple_call_result) { + for (const auto &tuple_component : tuple_call_result) { auto tuple_component_sv = - dynamic_cast(tuple_component.get()); + dynamic_cast(tuple_component.get()); results.push_back(tuple_component_sv->getValue()); } } else { @@ -84,7 +88,7 @@ TorchMlirOpVector LowerTorchMlirBuiltin( // Insert known tensor type information. unsigned tensor_type_idx = 0; - for (jit::Value* value : results) { + for (jit::Value *value : results) { if (value->type()->kind() == c10::TypeKind::TensorType) { TORCH_CHECK( tensor_type_idx < tensor_types.size(), function->graph()->toString(), @@ -97,23 +101,22 @@ TorchMlirOpVector LowerTorchMlirBuiltin( } // Ensure that we use up all the known tensor type information available. - TORCH_CHECK( - tensor_type_idx == tensor_types.size(), tensor_type_idx, - " known types were injected into jit::Value, but ", tensor_types.size(), - " were provided from lazy::Node!"); + TORCH_CHECK(tensor_type_idx == tensor_types.size(), tensor_type_idx, + " known types were injected into jit::Value, but ", + tensor_types.size(), " were provided from lazy::Node!"); return results; } -TorchMlirOpVector LowerTorchMlirBuiltin( - TorchMlirFunction function, c10::Symbol sym, - const c10::ArrayRef result_shapes, - const std::vector& arguments, - const std::vector& kwarguments) { +TorchMlirOpVector +LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym, + const c10::ArrayRef result_shapes, + const std::vector &arguments, + const std::vector &kwarguments) { std::vector tensor_types; // Generate types with fixed tensor shape information. - for (const Shape& shape : result_shapes) { + for (const Shape &shape : result_shapes) { tensor_types.push_back(torch::jit::TensorType::create( /*scalar_type=*/shape.scalar_type(), /*device=*/c10::nullopt, @@ -122,34 +125,34 @@ TorchMlirOpVector LowerTorchMlirBuiltin( /*requires_grad=*/c10::nullopt)); } - return LowerTorchMlirBuiltin( - function, sym, tensor_types, arguments, kwarguments); + return LowerTorchMlirBuiltin(function, sym, tensor_types, arguments, + kwarguments); } -TorchMlirOpVector LowerBuiltin( - const torch::lazy::Node* node, TorchMlirFunction function, - const std::vector& arguments, - const std::vector& kwarguments = {}) { - return LowerTorchMlirBuiltin( - function, node->op().op, node->shapes(), arguments, kwarguments); +TorchMlirOpVector +LowerBuiltin(const torch::lazy::Node *node, TorchMlirFunction function, + const std::vector &arguments, + const std::vector &kwarguments = {}) { + return LowerTorchMlirBuiltin(function, node->op().op, node->shapes(), + arguments, kwarguments); } -TorchMlirOpVector LowerBuiltin( - c10::Symbol sym, const c10::ArrayRef result_shapes, - TorchMlirFunction function, - const std::vector& arguments, - const std::vector& kwarguments = {}) { - return LowerTorchMlirBuiltin( - function, sym, result_shapes, arguments, kwarguments); +TorchMlirOpVector +LowerBuiltin(c10::Symbol sym, const c10::ArrayRef result_shapes, + TorchMlirFunction function, + const std::vector &arguments, + const std::vector &kwarguments = {}) { + return LowerTorchMlirBuiltin(function, sym, result_shapes, arguments, + kwarguments); } -TorchMlirOpVector LowerBuiltin( - c10::Symbol sym, const std::vector types, - TorchMlirFunction function, - const std::vector& arguments, - const std::vector& kwarguments = {}) { +TorchMlirOpVector +LowerBuiltin(c10::Symbol sym, const std::vector types, + TorchMlirFunction function, + const std::vector &arguments, + const std::vector &kwarguments = {}) { return LowerTorchMlirBuiltin(function, sym, types, arguments, kwarguments); } -c10::TensorType& cast_tensor_type(c10::TypePtr value_type) { +c10::TensorType &cast_tensor_type(c10::TypePtr value_type) { auto tensor_type = value_type->cast(); TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!"); @@ -157,8 +160,8 @@ c10::TensorType& cast_tensor_type(c10::TypePtr value_type) { } c10::optional> -get_tensor_type_shape(c10::TensorType& tensor_type) { - auto& symbolic_shape = tensor_type.symbolic_sizes(); +get_tensor_type_shape(c10::TensorType &tensor_type) { + auto &symbolic_shape = tensor_type.symbolic_sizes(); if (!symbolic_shape.rank()) { return c10::nullopt; } @@ -175,21 +178,21 @@ get_tensor_type_shape(c10::TensorType& tensor_type) { } std::vector compute_shape_copy(c10::TypePtr value_type) { - c10::TensorType& tensor_type = cast_tensor_type(value_type); + c10::TensorType &tensor_type = cast_tensor_type(value_type); auto maybe_dims = get_tensor_type_shape(tensor_type); TORCH_CHECK(maybe_dims.has_value(), "Cannot copy unranked tensor!"); auto scalar_type = tensor_type.scalarType(); - TORCH_CHECK( - scalar_type.has_value(), "Unable to copy due to lack of scalar type!"); + TORCH_CHECK(scalar_type.has_value(), + "Unable to copy due to lack of scalar type!"); return {Shape(scalar_type.value(), maybe_dims.value())}; } -std::vector compute_shape_slice( - c10::TypePtr value_type, int64_t dim, int64_t start, int64_t end, - int64_t step) { - c10::TensorType& tensor_type = cast_tensor_type(value_type); +std::vector compute_shape_slice(c10::TypePtr value_type, + int64_t dim, int64_t start, + int64_t end, int64_t step) { + c10::TensorType &tensor_type = cast_tensor_type(value_type); auto maybe_dims = get_tensor_type_shape(tensor_type); TORCH_CHECK(maybe_dims.has_value(), "Cannot slice unranked tensor!"); @@ -217,13 +220,13 @@ std::vector compute_shape_slice( } auto scalar_type = tensor_type.scalarType(); - TORCH_CHECK( - scalar_type.has_value(), "Unable to slice due to lack of scalar type!"); + TORCH_CHECK(scalar_type.has_value(), + "Unable to slice due to lack of scalar type!"); return {Shape(scalar_type.value(), dims)}; } -torch::jit::Value* -GenerateClone(torch::jit::Value* val, TorchMlirFunction function) { +torch::jit::Value *GenerateClone(torch::jit::Value *val, + TorchMlirFunction function) { std::vector clone_arguments; clone_arguments.emplace_back(val); @@ -234,20 +237,19 @@ GenerateClone(torch::jit::Value* val, TorchMlirFunction function) { return cloned.front(); } -void GenerateCopy( - torch::jit::Value* destination, torch::jit::Value* source, - TorchMlirFunction function) { +void GenerateCopy(torch::jit::Value *destination, torch::jit::Value *source, + TorchMlirFunction function) { std::vector arguments; arguments.emplace_back(destination); arguments.emplace_back(source); - LowerBuiltin( - at::aten::copy_, c10::ArrayRef(compute_shape_copy(source->type())), - function, arguments); + LowerBuiltin(at::aten::copy_, + c10::ArrayRef(compute_shape_copy(source->type())), + function, arguments); } -torch::jit::Value* GenerateSlice( - torch::jit::Value* base, int64_t dim, int64_t start, int64_t end, - int64_t step, TorchMlirFunction function) { +torch::jit::Value *GenerateSlice(torch::jit::Value *base, int64_t dim, + int64_t start, int64_t end, int64_t step, + TorchMlirFunction function) { std::vector arguments; arguments.emplace_back(base); arguments.emplace_back(dim); @@ -255,11 +257,11 @@ torch::jit::Value* GenerateSlice( arguments.emplace_back(end); arguments.emplace_back(step); - TorchMlirOpVector selected = LowerBuiltin( - at::aten::slice, - c10::ArrayRef( - compute_shape_slice(base->type(), dim, start, end, step)), - function, arguments); + TorchMlirOpVector selected = + LowerBuiltin(at::aten::slice, + c10::ArrayRef(compute_shape_slice(base->type(), dim, + start, end, step)), + function, arguments); TORCH_CHECK_EQ(selected.size(), 1); return selected.front(); } @@ -267,10 +269,10 @@ torch::jit::Value* GenerateSlice( // Node Lowerings // Default Node Lowering -TorchMlirOpVector TorchMlirNode::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector TorchMlirNode::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { std::vector arguments; - for (const torch::lazy::Output& output : operands()) { + for (const torch::lazy::Output &output : operands()) { arguments.emplace_back(loctx->GetOutputOp(output)); } return LowerBuiltin(this, function, arguments); @@ -280,19 +282,19 @@ TorchMlirOpVector TorchMlirNode::Lower( // Non-native nodes -TorchMlirOpVector -Cast::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector Cast::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); arguments.emplace_back(dtype); return LowerBuiltin(at::aten::to, shapes(), function, arguments); } -TorchMlirOpVector DeviceData::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector DeviceData::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { auto infoptr = data_->info(); auto deviceDataInfoPtr = - (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr; + (torch::lazy::LazyGraphExecutor::DeviceDataInfo *)infoptr; if (GRAPH_DUMP_ENABLED) { LOG(ERROR) << "Lowering device data node, tensor id " << deviceDataInfoPtr->tensor_id << std::endl; @@ -300,8 +302,8 @@ TorchMlirOpVector DeviceData::Lower( return {loctx->GetParameter(data_)}; } -TorchMlirOpVector Scalar::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector Scalar::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { auto options = at::TensorOptions() .device(torch::lazy::getBackend()->EagerFallbackDeviceType()) @@ -309,8 +311,8 @@ TorchMlirOpVector Scalar::Lower( return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))}; } -TorchMlirOpVector Expand::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector Expand::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); arguments.emplace_back(size); diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h index f9e028a5c..650bed045 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h @@ -18,14 +18,14 @@ namespace torch { namespace lazy { -typedef std::vector TorchMlirOpVector; +typedef std::vector TorchMlirOpVector; typedef std::shared_ptr TorchMlirFunction; TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin( TorchMlirFunction function, c10::Symbol sym, const c10::ArrayRef result_shapes, - const std::vector& arguments, - const std::vector& kwarguments = {}); + const std::vector &arguments, + const std::vector &kwarguments = {}); } // namespace lazy } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp b/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp index b4271df66..c4255068f 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp @@ -2,18 +2,16 @@ #include -#include "device_data.h" #include "../backend_impl.h" +#include "device_data.h" namespace torch { namespace lazy { DeviceData::DeviceData(std::shared_ptr data) - : TorchMlirNode( - ClassOpKind(), - data->shape(), - /*num_outputs=*/1, - /*hash_seed=*/static_cast(101)), + : TorchMlirNode(ClassOpKind(), data->shape(), + /*num_outputs=*/1, + /*hash_seed=*/static_cast(101)), data_(std::move(data)) { propagate_name(); } @@ -21,9 +19,11 @@ DeviceData::DeviceData(std::shared_ptr data) void DeviceData::propagate_name() { if (data_ && name_ != "") { // Add device data name to backend data - TorchMlirBackendData* mlir_data = dynamic_cast(data_.get()); + TorchMlirBackendData *mlir_data = + dynamic_cast(data_.get()); TORCH_CHECK(mlir_data); - auto* info = dynamic_cast(mlir_data->mlir_info()); + auto *info = + dynamic_cast(mlir_data->mlir_info()); TORCH_CHECK(info); info->name = name_; } @@ -34,7 +34,7 @@ void DeviceData::SetData(std::shared_ptr data) { propagate_name(); } -void DeviceData::SetName(const std::string& name) { +void DeviceData::SetName(const std::string &name) { name_ = name; propagate_name(); } @@ -43,12 +43,12 @@ std::string DeviceData::ToString() const { std::stringstream ss; ss << TorchMlirNode::ToString() << ", device=" << data_->device(); if (name_ != "") { - ss << ", name=" << name_; + ss << ", name=" << name_; } return ss.str(); } -const DeviceData* DeviceData::Cast(const Node* node) { +const DeviceData *DeviceData::Cast(const Node *node) { return NodeCast(node); } @@ -59,7 +59,7 @@ NodePtr DeviceData::Create(std::shared_ptr data) { // Ditching the old data_ is safe because tracing is done iteration // by iteration, and after we lauch the async device execution for the // previous iteration, data_ in DeviceData nodes are not needed anymore. - DeviceData* device_data = static_cast(node.get()); + DeviceData *device_data = static_cast(node.get()); device_data->SetData(data); return node; } diff --git a/projects/ltc/csrc/base_lazy_backend/ops/device_data.h b/projects/ltc/csrc/base_lazy_backend/ops/device_data.h index ad9d9d0eb..6f96d0749 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/device_data.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/device_data.h @@ -6,15 +6,12 @@ #include #include - namespace torch { namespace lazy { class TORCH_API DeviceData : public TorchMlirNode { - public: - static OpKind ClassOpKind() { - return ltc_device_data; - } +public: + static OpKind ClassOpKind() { return ltc_device_data; } explicit DeviceData(std::shared_ptr data); @@ -27,22 +24,23 @@ class TORCH_API DeviceData : public TorchMlirNode { std::string ToString() const override; - const std::shared_ptr& data() const { return data_; } + const std::shared_ptr &data() const { return data_; } void SetData(std::shared_ptr data); - TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override; + TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const override; - static const DeviceData* Cast(const Node* node); + static const DeviceData *Cast(const Node *node); // To reuse IR nodes, use this method to create DeviceData nodes // instead of calling the constructor directly. static NodePtr Create(std::shared_ptr data); - const std::string& GetName() const { return name_; } - void SetName(const std::string& name); + const std::string &GetName() const { return name_; } + void SetName(const std::string &name); - private: +private: void propagate_name(); std::shared_ptr data_; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp b/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp index 1df8be231..17e578946 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp @@ -15,12 +15,8 @@ namespace torch { namespace lazy { -Generic::Generic( - OpKind op, - OpList operands, - Shape shape, - size_t num_outputs, - hash_t hash_seed) +Generic::Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs, + hash_t hash_seed) : TorchMlirNode(op, operands, {std::move(shape)}, num_outputs, hash_seed), hash_seed_(hash_seed) {} diff --git a/projects/ltc/csrc/base_lazy_backend/ops/generic.h b/projects/ltc/csrc/base_lazy_backend/ops/generic.h index f294b1cfa..01794355a 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/generic.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/generic.h @@ -23,15 +23,11 @@ namespace lazy { // captured by the LowerFn), but they should instead create a dedicated IR node. // Doing the former would limit IR introspection. class TORCH_API Generic : public TorchMlirNode { - public: - Generic( - OpKind op, - OpList operands, - Shape shape, - size_t num_outputs = 1, - hash_t hash_seed = static_cast(0x5a2d296e9)); +public: + Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs = 1, + hash_t hash_seed = static_cast(0x5a2d296e9)); - private: +private: hash_t hash_seed_; }; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/index.cpp b/projects/ltc/csrc/base_lazy_backend/ops/index.cpp index 34af3e590..ffa2f06bb 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/index.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/index.cpp @@ -12,9 +12,9 @@ namespace torch { namespace lazy { -IndexTensor::IndexTensor(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - std::vector&& shapes) +IndexTensor::IndexTensor(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + std::vector &&shapes) : torch::lazy::TorchMlirNode(IndexTensor::ClassOpKind(), OpList{self, indices}, std::move(shapes), /* num_outputs */ 1, torch::lazy::MHash()) {} @@ -25,13 +25,13 @@ std::string IndexTensor::ToString() const { return ss.str(); } -bool IndexTensor::CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices) const { +bool IndexTensor::CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices) const { return false; } TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; @@ -49,10 +49,10 @@ TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function, return index_out; } -IndexPut::IndexPut(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - const torch::lazy::Value& values, bool accumulate, - std::vector&& shapes) +IndexPut::IndexPut(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate, + std::vector &&shapes) : torch::lazy::TorchMlirNode( IndexPut::ClassOpKind(), OpList{self, indices, values}, std::move(shapes), @@ -66,15 +66,15 @@ std::string IndexPut::ToString() const { return ss.str(); } -bool IndexPut::CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - const torch::lazy::Value& values, +bool IndexPut::CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate) const { return false; } TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; @@ -95,5 +95,5 @@ TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function, return index_out; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/index.h b/projects/ltc/csrc/base_lazy_backend/ops/index.h index e97760fc3..6f63cbc68 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/index.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/index.h @@ -15,44 +15,44 @@ namespace torch { namespace lazy { class IndexTensor : public torch::lazy::TorchMlirNode { - public: +public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::aten::index); } - IndexTensor(const torch::lazy::Value& self, const torch::lazy::Value& indices, - std::vector&& shapes); + IndexTensor(const torch::lazy::Value &self, const torch::lazy::Value &indices, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices) const; + bool CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; }; class IndexPut : public torch::lazy::TorchMlirNode { - public: +public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::aten::index_put); } - IndexPut(const torch::lazy::Value& self, const torch::lazy::Value& indices, - const torch::lazy::Value& values, bool accumulate, - std::vector&& shapes); + IndexPut(const torch::lazy::Value &self, const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - const torch::lazy::Value& values, bool accumulate) const; + bool CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; bool accumulate; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp index 0653e4467..e3db5ca37 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp @@ -15,7 +15,7 @@ namespace torch { namespace lazy { -IValueConstant::IValueConstant(const c10::IValue& value) +IValueConstant::IValueConstant(const c10::IValue &value) : torch::lazy::TorchMlirNode(IValueConstant::ClassOpKind(), OpList{}, std::vector{}, /* num_outputs */ 1, torch::lazy::MHash()), @@ -28,9 +28,9 @@ std::string IValueConstant::ToString() const { } TorchMlirOpVector IValueConstant::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { return {loctx->graph()->insertConstant(value)}; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h index 8f488ff47..48fb95b73 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h @@ -18,20 +18,20 @@ namespace lazy { // parameter which is helpful in different usecases when we need custom // native ops lowering to torch-mlir IR nodes. class IValueConstant : public torch::lazy::TorchMlirNode { - public: +public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::prim::Constant); } - IValueConstant(const c10::IValue& value); + IValueConstant(const c10::IValue &value); std::string ToString() const override; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; c10::IValue value; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/split.cpp b/projects/ltc/csrc/base_lazy_backend/ops/split.cpp index d20d298df..91cbd2a52 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/split.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/split.cpp @@ -13,10 +13,10 @@ namespace torch { namespace lazy { SplitWithSizesCopy::SplitWithSizesCopy( - const torch::lazy::Value& self, const ::std::vector& split_sizes, - const int64_t& dim, std::vector&& shapes) + const torch::lazy::Value &self, const ::std::vector &split_sizes, + const int64_t &dim, std::vector &&shapes) : torch::lazy::TorchMlirNode(SplitWithSizesCopy::ClassOpKind(), - OpList{ self }, std::move(shapes), + OpList{self}, std::move(shapes), split_sizes.size() /* num_outputs */, torch::lazy::MHash(split_sizes, dim)), split_sizes(split_sizes), dim(dim) {} @@ -29,15 +29,15 @@ std::string SplitWithSizesCopy::ToString() const { return ss.str(); } -bool SplitWithSizesCopy::CanBeReused(const torch::lazy::Value& self, - const ::std::vector& split_sizes, - const int64_t& dim) const { +bool SplitWithSizesCopy::CanBeReused(const torch::lazy::Value &self, + const ::std::vector &split_sizes, + const int64_t &dim) const { return false; } TorchMlirOpVector SplitWithSizesCopy::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; @@ -55,13 +55,13 @@ SplitWithSizesCopy::Lower(TorchMlirFunction function, return split_with_sizes_copy_out; } -SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, - const int64_t& dim, - std::vector&& shapes, +SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, + const int64_t &dim, + std::vector &&shapes, const size_t num_outputs) : torch::lazy::TorchMlirNode(SplitCopyTensor::ClassOpKind(), - OpList{ self, split_size }, std::move(shapes), + OpList{self, split_size}, std::move(shapes), num_outputs, torch::lazy::MHash(dim)), dim(dim) {} @@ -72,15 +72,15 @@ std::string SplitCopyTensor::ToString() const { return ss.str(); } -bool SplitCopyTensor::CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, - const int64_t& dim) const { +bool SplitCopyTensor::CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, + const int64_t &dim) const { return false; } TorchMlirOpVector SplitCopyTensor::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/split.h b/projects/ltc/csrc/base_lazy_backend/ops/split.h index 8593d5628..116ddd64a 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/split.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/split.h @@ -20,19 +20,19 @@ public: return torch::lazy::OpKind(at::aten::split_with_sizes_copy); } - SplitWithSizesCopy(const torch::lazy::Value& self, - const ::std::vector& split_sizes, - const int64_t& dim, - std::vector&& shapes); + SplitWithSizesCopy(const torch::lazy::Value &self, + const ::std::vector &split_sizes, + const int64_t &dim, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const ::std::vector& split_sizes, - const int64_t& dim) const; + bool CanBeReused(const torch::lazy::Value &self, + const ::std::vector &split_sizes, + const int64_t &dim) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; std::vector split_sizes; int64_t dim; @@ -44,19 +44,19 @@ public: return torch::lazy::OpKind(at::aten::split_copy); } - SplitCopyTensor(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, const int64_t& dim, - std::vector&& shapes, + SplitCopyTensor(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, const int64_t &dim, + std::vector &&shapes, const size_t num_outputs = 1); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, - const int64_t& dim) const; + bool CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, + const int64_t &dim) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; int64_t dim; }; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h b/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h index c6b75baaf..402355031 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h @@ -17,61 +17,65 @@ namespace torch { namespace lazy { - -// This IR was copied from code-generated output, but the entire _to_copy operator -// cannot be trivially code genereated since it is only desirable to capture IR for -// certain permutaions of _to_copy (e.g. dtype), and for the others it is difficult to even invoke -// the aten/eager fallback necessitating directly implementing the right to(device) behavior +// This IR was copied from code-generated output, but the entire _to_copy +// operator cannot be trivially code genereated since it is only desirable to +// capture IR for certain permutaions of _to_copy (e.g. dtype), and for the +// others it is difficult to even invoke the aten/eager fallback necessitating +// directly implementing the right to(device) behavior class ToCopy : public torch::lazy::TorchMlirNode { - public: - ToCopy(const torch::lazy::Value& self, const c10::optional& dtype, const c10::optional& layout, const c10::optional& device, const c10::optional& pin_memory, const bool& non_blocking, const c10::optional& memory_format, std::vector&& shapes) - : torch::lazy::TorchMlirNode(torch::lazy::OpKind(at::aten::_to_copy), - {self}, std::move(shapes), - /* num_outputs */ 1, - torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, memory_format)), +public: + ToCopy(const torch::lazy::Value &self, + const c10::optional &dtype, + const c10::optional &layout, + const c10::optional &device, + const c10::optional &pin_memory, const bool &non_blocking, + const c10::optional &memory_format, + std::vector &&shapes) + : torch::lazy::TorchMlirNode( + torch::lazy::OpKind(at::aten::_to_copy), {self}, std::move(shapes), + /* num_outputs */ 1, + torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, + memory_format)), - dtype(dtype), - layout(layout), - device(device), - pin_memory(pin_memory), - non_blocking(non_blocking), - memory_format(memory_format) {} + dtype(dtype), layout(layout), device(device), pin_memory(pin_memory), + non_blocking(non_blocking), memory_format(memory_format) {} std::string ToString() const override { std::stringstream ss; ss << torch::lazy::TorchMlirNode::ToString(); if (dtype.has_value()) { - ss << ", dtype=" << dtype.value(); + ss << ", dtype=" << dtype.value(); } else { - ss << ", dtype=null"; + ss << ", dtype=null"; } if (layout.has_value()) { - ss << ", layout=" << layout.value(); + ss << ", layout=" << layout.value(); } else { - ss << ", layout=null"; + ss << ", layout=null"; } if (device.has_value()) { - ss << ", device=" << device.value(); + ss << ", device=" << device.value(); } else { - ss << ", device=null"; + ss << ", device=null"; } if (pin_memory.has_value()) { - ss << ", pin_memory=" << pin_memory.value(); + ss << ", pin_memory=" << pin_memory.value(); } else { - ss << ", pin_memory=null"; + ss << ", pin_memory=null"; } ss << ", non_blocking=" << non_blocking; if (memory_format.has_value()) { - ss << ", memory_format=" << memory_format.value(); + ss << ", memory_format=" << memory_format.value(); } else { - ss << ", memory_format=null"; + ss << ", memory_format=null"; } return ss.str(); } - torch::lazy::TorchMlirOpVector Lower(TorchMlirFunction function, - torch::lazy::TorchMlirLoweringContext* loctx) const override { - std::vector arguments; + torch::lazy::TorchMlirOpVector + Lower(TorchMlirFunction function, + torch::lazy::TorchMlirLoweringContext *loctx) const override { + std::vector arguments; std::vector kwarguments; arguments.reserve(1); kwarguments.reserve(6); @@ -83,11 +87,12 @@ class ToCopy : public torch::lazy::TorchMlirNode { kwarguments.emplace_back("pin_memory", pin_memory); kwarguments.emplace_back("non_blocking", non_blocking); kwarguments.emplace_back("memory_format", memory_format); - torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments); + torch::lazy::TorchMlirOpVector _to_copy_out = + torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), + arguments, kwarguments); TORCH_CHECK_EQ(_to_copy_out.size(), 1); return _to_copy_out; - } c10::optional dtype; @@ -97,5 +102,5 @@ class ToCopy : public torch::lazy::TorchMlirNode { bool non_blocking; c10::optional memory_format; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp index a5526366c..c43c84d24 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp @@ -12,9 +12,9 @@ namespace torch { namespace lazy { -UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim, - std::vector&& shapes) - : torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{ self }, +UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value &self, const int64_t &dim, + std::vector &&shapes) + : torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{self}, std::move(shapes), self.shape().size(dim), /* num_outputs */ torch::lazy::MHash(dim)), @@ -27,13 +27,13 @@ std::string UnbindCopyInt::ToString() const { return ss.str(); } -bool UnbindCopyInt::CanBeReused(const torch::lazy::Value& self, - const int64_t& dim) const { +bool UnbindCopyInt::CanBeReused(const torch::lazy::Value &self, + const int64_t &dim) const { return false; } TorchMlirOpVector UnbindCopyInt::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h index 766752c16..9d6d83842 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h @@ -20,15 +20,15 @@ public: return torch::lazy::OpKind(at::aten::unbind_copy); } - UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim, - std::vector&& shapes); + UnbindCopyInt(const torch::lazy::Value &self, const int64_t &dim, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, const int64_t& dim) const; + bool CanBeReused(const torch::lazy::Value &self, const int64_t &dim) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; int64_t dim; }; diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index 325e89e14..8e3b2c070 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -21,21 +21,20 @@ namespace lazy { // TODO(henrytu): Upstream these shape inference functions to PyTorch in the // future. -std::vector compute_shape_add(const at::Tensor& self, - const at::Scalar& other, - const at::Scalar& alpha) { +std::vector compute_shape_add(const at::Tensor &self, + const at::Scalar &other, + const at::Scalar &alpha) { return {Shape(self.scalar_type(), self.sizes().vec())}; } - -std::vector compute_shape_sub(const at::Tensor& self, - const at::Scalar& other, - const at::Scalar& alpha) { +std::vector compute_shape_sub(const at::Tensor &self, + const at::Scalar &other, + const at::Scalar &alpha) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_div(const at::Tensor& self, - const at::Scalar& other) { +std::vector compute_shape_div(const at::Tensor &self, + const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } @@ -85,7 +84,7 @@ compute_shape_quantize_per_tensor(const at::Tensor &self, double scale, return {Shape(dtype, self.sizes().vec())}; } -std::vector compute_shape_isinf(const at::Tensor& self) { +std::vector compute_shape_isinf(const at::Tensor &self) { return {Shape(at::kBool, self.sizes().vec())}; } @@ -96,9 +95,8 @@ std::vector compute_shape_quantize_per_channel( } std::vector compute_shape_max_pool3d_with_indices( - const at::Tensor& self, at::IntArrayRef kernel_size, - at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, - bool ceil_mode) { + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { auto in_sizes = self.sizes().vec(); std::vector dhw(3, 0); std::vector paddings = padding.vec(); @@ -106,18 +104,19 @@ std::vector compute_shape_max_pool3d_with_indices( std::vector dilations = dilation.vec(); std::vector strides = stride.vec(); TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ", - in_sizes); - TORCH_CHECK(kernel_size.size() == 3 && - stride.size() == 3 && - padding.size() == 3 && - dilation.size() == 3, "max_pool3d requires 3D operands, but got ", - kernel_size, stride, padding, dilation); + in_sizes); + TORCH_CHECK(kernel_size.size() == 3 && stride.size() == 3 && + padding.size() == 3 && dilation.size() == 3, + "max_pool3d requires 3D operands, but got ", kernel_size, stride, + padding, dilation); int64_t batch = in_sizes[0]; int64_t channel = in_sizes[1]; // NCDHW // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html - for (auto i = 0UL; i<3; ++i) { - double out_size = (in_sizes[2+i] + 2 * paddings[i] - dilations[i] * - (ksizes[i] - 1) - 1) / (double)strides[i] + 1; + for (auto i = 0UL; i < 3; ++i) { + double out_size = (in_sizes[2 + i] + 2 * paddings[i] - + dilations[i] * (ksizes[i] - 1) - 1) / + (double)strides[i] + + 1; if (ceil_mode) dhw[i] = (int64_t)std::ceil(out_size); else @@ -129,52 +128,54 @@ std::vector compute_shape_max_pool3d_with_indices( } std::vector compute_shape_max_pool3d_with_indices_backward( - const at::Tensor & grad_output, const at::Tensor & self, + const at::Tensor &grad_output, const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, - const at::Tensor & indices) { + const at::Tensor &indices) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_mse_loss_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Tensor& target, int64_t reduction) { +std::vector +compute_shape_mse_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, int64_t reduction) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_mul(const at::Tensor& self, - const at::Scalar& other) { +std::vector compute_shape_mul(const at::Tensor &self, + const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_var( - const at::Tensor& self, at::OptionalIntArrayRef dim, - const c10::optional & correction, bool keepdim) { +std::vector +compute_shape_var(const at::Tensor &self, at::OptionalIntArrayRef dim, + const c10::optional &correction, bool keepdim) { // Result of variance is scalar tensor. return {Shape(self.scalar_type(), {})}; } -std::vector compute_shape_nan_to_num( - const at::Tensor & self, c10::optional nan, - c10::optional posinf, c10::optional neginf) { +std::vector +compute_shape_nan_to_num(const at::Tensor &self, c10::optional nan, + c10::optional posinf, + c10::optional neginf) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_hardtanh( - const at::Tensor& self, const at::Scalar& min_val, - const at::Scalar& max_val) { +std::vector +compute_shape_hardtanh(const at::Tensor &self, const at::Scalar &min_val, + const at::Scalar &max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_hardtanh_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Scalar& min_val, const at::Scalar& max_val) { + const at::Tensor &grad_output, const at::Tensor &self, + const at::Scalar &min_val, const at::Scalar &max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_where(const at::Tensor& condition, - const at::Tensor& self, - const at::Tensor& other) { +std::vector compute_shape_where(const at::Tensor &condition, + const at::Tensor &self, + const at::Tensor &other) { // There are cases like - // torch.aten.where.self %42, %arg17, %37 : !torch.vtensor<[15,10],i1>, // !torch.vtensor<[],f32>, !torch.vtensor<[15,10],f32>. @@ -201,32 +202,32 @@ std::vector compute_shape_where(const at::Tensor& condition, return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } -std::vector compute_shape_bucketize( - const at::Tensor& self, const at::Tensor& boundaries, bool out_int32, - bool right) { +std::vector +compute_shape_bucketize(const at::Tensor &self, const at::Tensor &boundaries, + bool out_int32, bool right) { auto dtype = out_int32 ? at::kInt : at::kLong; return {Shape(dtype, self.sizes().vec())}; } -std::vector compute_shape_copy(const at::Tensor& self, - const at::Tensor& src, +std::vector compute_shape_copy(const at::Tensor &self, + const at::Tensor &src, bool non_blocking) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_floor_divide( - const at::Tensor& self, const at::Tensor& other) { +std::vector +compute_shape_floor_divide(const at::Tensor &self, const at::Tensor &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_fmod(const at::Tensor& self, - const at::Scalar& other) { +std::vector compute_shape_fmod(const at::Tensor &self, + const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_native_group_norm( - const at::Tensor& input, const c10::optional& weight, - const c10::optional& bias, int64_t N, int64_t C, int64_t HxW, + const at::Tensor &input, const c10::optional &weight, + const c10::optional &bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { TORCH_CHECK(input.sizes().size() >= 2, @@ -244,9 +245,10 @@ std::vector compute_shape_native_group_norm( return shapes; } -std::vector compute_shape_im2col( - const at::Tensor& self, at::IntArrayRef kernel_size, - at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { +std::vector +compute_shape_im2col(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, + at::IntArrayRef stride) { auto self_meta = at::native::empty_strided_meta_symint( self.sym_sizes(), self.sym_strides(), @@ -260,8 +262,8 @@ std::vector compute_shape_im2col( } std::vector compute_shape_native_group_norm_backward( - const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean, - const at::Tensor& rstd, const c10::optional& weight, int64_t N, + const at::Tensor &grad_out, const at::Tensor &input, const at::Tensor &mean, + const at::Tensor &rstd, const c10::optional &weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array output_mask) { TORCH_CHECK(input.sizes().size() >= 2, @@ -280,8 +282,8 @@ std::vector compute_shape_native_group_norm_backward( return shapes; } -std::vector compute_shape_remainder( - const at::Tensor& self, const at::Scalar& other) { +std::vector +compute_shape_remainder(const at::Tensor &self, const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } @@ -313,21 +315,22 @@ compute_shape_reflection_pad2d(const at::Tensor &self, return {Shape(self.scalar_type(), out_sizes)}; } -std::vector compute_shape_uniform( - const at::Tensor& self, double from, double to, - c10::optional generator) { +std::vector +compute_shape_uniform(const at::Tensor &self, double from, double to, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_normal_functional( - const at::Tensor& self, double mean, double std, - c10::optional generator) { +std::vector +compute_shape_normal_functional(const at::Tensor &self, double mean, double std, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_multinomial( - const at::Tensor& self, int64_t num_samples, bool replacement, - c10::optional generator) { +std::vector +compute_shape_multinomial(const at::Tensor &self, int64_t num_samples, + bool replacement, + c10::optional generator) { // Input tensor can be either 1D or 2D. The last dim of output // should be 'num_samples'. So the output shape can be either // [num_samples] or [m, num_samples]. @@ -337,35 +340,38 @@ std::vector compute_shape_multinomial( return {Shape(at::kLong, ishape)}; } -std::vector compute_shape_eye( - int64_t n, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_eye(int64_t n, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto out_meta = at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } -std::vector compute_shape_eye( - int64_t n, int64_t m, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_eye(int64_t n, int64_t m, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto out_meta = at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } -std::vector compute_shape_arange( - const at::Scalar& end, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_arange(const at::Scalar &end, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto out_meta = at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } std::vector compute_shape_arange( - const at::Scalar& start, const at::Scalar& end, + const at::Scalar &start, const at::Scalar &end, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta), @@ -374,7 +380,7 @@ std::vector compute_shape_arange( } std::vector compute_shape_arange( - const at::Scalar& start, const at::Scalar& end, const at::Scalar& step, + const at::Scalar &start, const at::Scalar &end, const at::Scalar &step, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { auto out_meta = at::arange(start, end, step, dtype, layout, @@ -383,34 +389,37 @@ std::vector compute_shape_arange( } std::vector compute_shape_full( - at::IntArrayRef size, const at::Scalar& fill_value, + at::IntArrayRef size, const at::Scalar &fill_value, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_ones( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_ones(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_zeros( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_zeros(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_empty( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory, - c10::optional memory_format) { +std::vector +compute_shape_empty(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + c10::optional memory_format) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } @@ -423,20 +432,21 @@ std::vector compute_shape_empty_strided( Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_fill(const at::Tensor& self, - const at::Scalar& value) { +std::vector compute_shape_fill(const at::Tensor &self, + const at::Scalar &value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_fill(const at::Tensor& self, - const at::Tensor& value) { +std::vector compute_shape_fill(const at::Tensor &self, + const at::Tensor &value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_randn( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_randn(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } @@ -457,36 +467,39 @@ std::vector compute_shape_randint( Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_resize( - const at::Tensor & self, at::IntArrayRef size, - c10::optional memory_format) { +std::vector +compute_shape_resize(const at::Tensor &self, at::IntArrayRef size, + c10::optional memory_format) { return {Shape(self.scalar_type(), size.vec())}; } -std::vector compute_shape_bernoulli( - const at::Tensor& self, const at::Tensor &p, - c10::optional generator) { +std::vector +compute_shape_bernoulli(const at::Tensor &self, const at::Tensor &p, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_scalar_tensor( - const at::Scalar & s, c10::optional dtype, + const at::Scalar &s, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { return {Shape(dtype.value_or(s.type()), c10::ArrayRef{})}; } -std::vector compute_shape_roll( - const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) { +std::vector compute_shape_roll(const at::Tensor &self, + at::IntArrayRef shifts, + at::IntArrayRef dims) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { - auto out_meta = - at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory); +std::vector compute_shape_linspace( + const at::Scalar &start, const at::Scalar &end, int64_t steps, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + auto out_meta = at::linspace(start, end, steps, dtype, layout, + c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } - -} // namespace lazy +} // namespace lazy } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/tensor.cpp b/projects/ltc/csrc/base_lazy_backend/tensor.cpp index 82ae6cc27..5be4ab369 100644 --- a/projects/ltc/csrc/base_lazy_backend/tensor.cpp +++ b/projects/ltc/csrc/base_lazy_backend/tensor.cpp @@ -14,16 +14,16 @@ namespace torch { namespace lazy { -at::Tensor CreateFunctionalizedAtenFromLtcTensor( - const LazyTensorPtr& ltc_tensor) { +at::Tensor +CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr <c_tensor) { at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor); if (!c10::impl::tls_is_dispatch_key_excluded( - c10::DispatchKey::Functionalize) && + c10::DispatchKey::Functionalize) && !at::functionalization::impl::isFunctionalTensor(tensor)) { return at::functionalization::impl::to_functional_tensor(tensor); } return tensor; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/tensor.h b/projects/ltc/csrc/base_lazy_backend/tensor.h index 4e39dd095..18e63ef68 100644 --- a/projects/ltc/csrc/base_lazy_backend/tensor.h +++ b/projects/ltc/csrc/base_lazy_backend/tensor.h @@ -18,7 +18,8 @@ namespace lazy { // should have explicit tensor functinoalization. Otherwise we can get // unfanctionalized primitives or in the worst case if we apply inplace // operations to unfunctionalized tensor it won't be captured in LTC graph. -TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor); +TORCH_API at::Tensor +CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr <c_tensor); } // namespace lazy } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/utils/exception.h b/projects/ltc/csrc/base_lazy_backend/utils/exception.h index 96510d830..533677ad8 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/exception.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/exception.h @@ -21,8 +21,8 @@ } #define UNIMPLEMENTED_FUNCTION_ERROR() \ - UNIMPLEMENTED_ERROR( \ - "\n\t" << __FILE__ << ":" << __LINE__ << " " << __PRETTY_FUNCTION__) + UNIMPLEMENTED_ERROR("\n\t" << __FILE__ << ":" << __LINE__ << " " \ + << __PRETTY_FUNCTION__) #define UNSUPPORTED_ERROR(msg) \ { \ diff --git a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp index 9ca8b666a..a4f367371 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp +++ b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp @@ -7,9 +7,9 @@ namespace torch { namespace jit { -void ConvertScalarImplicit(std::shared_ptr& graph) { +void ConvertScalarImplicit(std::shared_ptr &graph) { DepthFirstGraphNodeIterator it(graph); - for (auto* node = it.next(); node != nullptr; node = it.next()) { + for (auto *node = it.next(); node != nullptr; node = it.next()) { if (node->kind() != c10::aten::ScalarImplicit) { continue; } @@ -27,15 +27,13 @@ void ConvertScalarImplicit(std::shared_ptr& graph) { node_type = c10::aten::FloatImplicit; output_type = FloatType::get(); } else { - throw std::runtime_error( - "Expected isIntegralType or isFloatingType"); + throw std::runtime_error("Expected isIntegralType or isFloatingType"); } - Value * output = graph - ->create(node_type, {input}) - ->insertBefore(node) - ->output() - ->setType(output_type); + Value *output = graph->create(node_type, {input}) + ->insertBefore(node) + ->output() + ->setType(output_type); node->output()->replaceAllUsesWith(output); node->destroy(); } diff --git a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h index 2c4214cfc..d9e47b464 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h @@ -4,7 +4,7 @@ namespace torch { namespace jit { // Convert ScalarImplicit to IntImplicit or FloatImplicit. -TORCH_API void ConvertScalarImplicit(std::shared_ptr& graph); +TORCH_API void ConvertScalarImplicit(std::shared_ptr &graph); } // namespace jit } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h index 281331992..a5a524b05 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h @@ -1,49 +1,49 @@ #pragma once -#include #include +#include #include - template -std::ostream& string_join(std::ostream& out, const std::vector& v, const std::string& delimiter) { - size_t i = 0; - for (const T& e : v) { - if ((i++) > 0) { out << delimiter; } - out << e; +std::ostream &string_join(std::ostream &out, const std::vector &v, + const std::string &delimiter) { + size_t i = 0; + for (const T &e : v) { + if ((i++) > 0) { + out << delimiter; } - return out; + out << e; + } + return out; } template -std::string string_join(const std::vector& v, const std::string& delimiter) { - std::ostringstream joined; - string_join(joined, v, delimiter); - return joined.str(); +std::string string_join(const std::vector &v, const std::string &delimiter) { + std::ostringstream joined; + string_join(joined, v, delimiter); + return joined.str(); } -inline std::vector string_split( - const std::string& str, - const std::string& sep -) { - std::vector tokens; - std::size_t pos1 = str.find_first_not_of(sep); - while (pos1 != std::string::npos) { - std::size_t pos2 = str.find_first_of(sep, pos1); - if (pos2 == std::string::npos) { - tokens.push_back(str.substr(pos1)); - pos1 = pos2; - } else { - tokens.push_back(str.substr(pos1, pos2 - pos1)); - pos1 = str.find_first_not_of(sep, pos2 + 1); - } +inline std::vector string_split(const std::string &str, + const std::string &sep) { + std::vector tokens; + std::size_t pos1 = str.find_first_not_of(sep); + while (pos1 != std::string::npos) { + std::size_t pos2 = str.find_first_of(sep, pos1); + if (pos2 == std::string::npos) { + tokens.push_back(str.substr(pos1)); + pos1 = pos2; + } else { + tokens.push_back(str.substr(pos1, pos2 - pos1)); + pos1 = str.find_first_not_of(sep, pos2 + 1); } - return tokens; + } + return tokens; } /* * Returns true if str starts with prefix */ -inline bool startswith(const std::string& str, const std::string& prefix) { - return str.rfind(prefix, 0) == 0; +inline bool startswith(const std::string &str, const std::string &prefix) { + return str.rfind(prefix, 0) == 0; } diff --git a/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h index 5ae149049..f6c51ba61 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h @@ -6,24 +6,25 @@ namespace sys_util { template -static T GetEnv(const std::string& name, const T& default_value = T(0)) { - const char* env = std::getenv(name.c_str()); +static T GetEnv(const std::string &name, const T &default_value = T(0)) { + const char *env = std::getenv(name.c_str()); if (!env) { return default_value; } return T(std::atoi(env)); } -static std::string GetEnvString(const std::string& name, const std::string& default_value) { - const char* env = std::getenv(name.c_str()); +static std::string GetEnvString(const std::string &name, + const std::string &default_value) { + const char *env = std::getenv(name.c_str()); if (!env) { return default_value; } return std::string(env); } -static bool GetEnvBool(const char* name, bool defval) { - const char* env = std::getenv(name); +static bool GetEnvBool(const char *name, bool defval) { + const char *env = std::getenv(name); if (env == nullptr) { return defval; } diff --git a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp index cdd971680..71a0e89f4 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp +++ b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp @@ -3,84 +3,90 @@ #include "../generated/LazyIr.h" #include "../mlir_node.h" - namespace torch { namespace lazy { -bool is_detach_copy(const torch::lazy::Node* node) { - return node && node->op() == torch::lazy::DetachCopy::ClassOpKind(); +bool is_detach_copy(const torch::lazy::Node *node) { + return node && node->op() == torch::lazy::DetachCopy::ClassOpKind(); } -bool is_detach_copy(const torch::lazy::Value& value) { - return is_detach_copy(value.node.get()); +bool is_detach_copy(const torch::lazy::Value &value) { + return is_detach_copy(value.node.get()); } -torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) { - if (!node) { return nullptr; } - - torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); - while(mlir_node && is_detach_copy(mlir_node)) { - mlir_node = mlir_node->mlir_node(0); - } - if (!mlir_node) { - return node; - } - return mlir_node; -} - -const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) { - if (!node) { return nullptr; } - - const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); - while(mlir_node && is_detach_copy(mlir_node)) { - mlir_node = mlir_node->mlir_node(0); - } - if (!mlir_node) { - return node; - } - return mlir_node; -} - - -torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) { - if (!node) { - return nullptr; - } - node = extract_non_detach_copy_node(node); - if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { - return dynamic_cast(node); - } +torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *node) { + if (!node) { return nullptr; -} -const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node* node) { - if (!node) { - return nullptr; - } - node = extract_non_detach_copy_node(node); - if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { - return dynamic_cast(node); - } - return nullptr; -} -torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) { - if (!value) { - return nullptr; - } - return device_data_cast(value.node.get()); + } + + torch::lazy::TorchMlirNode *mlir_node = + dynamic_cast(node); + while (mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; } -torch::lazy::DeviceData* device_data_cast( - const at::Tensor& tensor, c10::optional device -) { - if (!device) { - device = torch::lazy::GetBackendDevice(tensor); - } - TORCH_CHECK(device); - torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device); - if (lazy_tensor) { - return device_data_cast(lazy_tensor->GetIrValue()); - } +const torch::lazy::Node * +extract_non_detach_copy_node(const torch::lazy::Node *node) { + if (!node) { return nullptr; + } + + const torch::lazy::TorchMlirNode *mlir_node = + dynamic_cast(node); + while (mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; } -} // namespace lazy -} // namespace torch +torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *node) { + if (!node) { + return nullptr; + } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); + } + return nullptr; +} +const torch::lazy::DeviceData *device_data_cast(const torch::lazy::Node *node) { + if (!node) { + return nullptr; + } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); + } + return nullptr; +} +torch::lazy::DeviceData *device_data_cast(const torch::lazy::Value &value) { + if (!value) { + return nullptr; + } + return device_data_cast(value.node.get()); +} + +torch::lazy::DeviceData * +device_data_cast(const at::Tensor &tensor, + c10::optional device) { + if (!device) { + device = torch::lazy::GetBackendDevice(tensor); + } + TORCH_CHECK(device); + torch::lazy::LazyTensorPtr lazy_tensor = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device); + if (lazy_tensor) { + return device_data_cast(lazy_tensor->GetIrValue()); + } + return nullptr; +} + +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h index 745be78c3..f8e5e3172 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h @@ -8,18 +8,21 @@ namespace torch { namespace lazy { -TORCH_API bool is_detach_copy(const torch::lazy::Node*); -TORCH_API bool is_detach_copy(const torch::lazy::Value&); +TORCH_API bool is_detach_copy(const torch::lazy::Node *); +TORCH_API bool is_detach_copy(const torch::lazy::Value &); -TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*); -TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*); +TORCH_API torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *); +TORCH_API const torch::lazy::Node * +extract_non_detach_copy_node(const torch::lazy::Node *); -TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*); -TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*); -TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value); -TORCH_API torch::lazy::DeviceData* device_data_cast( - const at::Tensor& tensor, c10::optional device = c10::nullopt -); +TORCH_API torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *); +TORCH_API const torch::lazy::DeviceData * +device_data_cast(const torch::lazy::Node *); +TORCH_API torch::lazy::DeviceData * +device_data_cast(const torch::lazy::Value &value); +TORCH_API torch::lazy::DeviceData *device_data_cast( + const at::Tensor &tensor, + c10::optional device = c10::nullopt); -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp index 4bcb9347b..8708ff06a 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -30,7 +30,7 @@ namespace lazy { /// Returns true if a string begins with another. inline bool beginswith(const std::string& s, const std::string& t) { - return s.size() >= t.size() && s.compare(0, t.size(), t) == 0; + return s.size() >= t.size() && s.compare(0, t.size(), t) == 0; } struct ReferenceLazyBackendDeviceType : public BackendDeviceType { @@ -73,10 +73,8 @@ public: // Vendor backend specific lowering can be exec here before returning. for (const auto& instance : instances) { TORCH_CHECK( - instance->in_mark_step, - "Compile outside of mark step:\n", - GetComputationBackendText(instance) - ); + instance->in_mark_step, "Compile outside of mark step:\n", + GetComputationBackendText(instance)); // Store computation instance for external access after compilation. GetLatestComputation() = instance; } @@ -114,16 +112,17 @@ public: // Convert any lazy devices to cpu devices to ensure // that the values are actually computed if (node->outputs().size() == 1 && - node->output()->type()->kind() == - c10::TypeKind::DeviceObjType) { - auto value_sym = torch::jit::Symbol::attr("value"); - TORCH_CHECK(node->hasAttribute(value_sym), - "Expected node to have 'value' attribute."); - TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s, - "Expected 'value' attribute to be a string."); - if (beginswith(node->s(value_sym), "lazy")) { - node->s_(value_sym, "cpu"); - } + node->output()->type()->kind() == c10::TypeKind::DeviceObjType) { + auto value_sym = torch::jit::Symbol::attr("value"); + TORCH_CHECK( + node->hasAttribute(value_sym), + "Expected node to have 'value' attribute."); + TORCH_CHECK( + node->kindOf(value_sym) == torch::jit::AttributeKind::s, + "Expected 'value' attribute to be a string."); + if (beginswith(node->s(value_sym), "lazy")) { + node->s_(value_sym, "cpu"); + } } } @@ -132,7 +131,8 @@ public: for (const auto& argument : arguments) { const auto mlir_data = std::static_pointer_cast(argument); - auto* info = dynamic_cast(mlir_data->mlir_info()); + auto* info = + dynamic_cast(mlir_data->mlir_info()); TORCH_CHECK(info); if (info->scalar.has_value()) { stack.emplace_back(info->scalar.value()); diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp index f4b8cd9ba..2cbb6d6f1 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp @@ -8,8 +8,8 @@ //===----------------------------------------------------------------------===// #include "torch/csrc/jit/python/pybind.h" -#include "torch/csrc/lazy/core/config.h" #include "torch/csrc/lazy/backend/backend_interface.h" +#include "torch/csrc/lazy/core/config.h" #include #include @@ -56,8 +56,8 @@ void Initialize() { } if (ir_debug) { - FLAGS_torch_lazy_ir_debug = true; - std::cout << "Enabled lazy tensor IR debugging." << std::endl; + FLAGS_torch_lazy_ir_debug = true; + std::cout << "Enabled lazy tensor IR debugging." << std::endl; } } @@ -82,15 +82,17 @@ PYBIND11_MODULE(_REFERENCE_LAZY_BACKEND, m) { torch::lazy::GetLatestComputation().get()); return py::cast(computation); }); - m.def("set_parameter_name", - [](const at::Tensor& tensor, const std::string& name) -> bool { - torch::lazy::DeviceData* ir_node = torch::lazy::device_data_cast(tensor); - if (ir_node) { - ir_node->SetName(name); - return true; - } - return false; - }); + m.def( + "set_parameter_name", + [](const at::Tensor& tensor, const std::string& name) -> bool { + torch::lazy::DeviceData* ir_node = + torch::lazy::device_data_cast(tensor); + if (ir_node) { + ir_node->SetName(name); + return true; + } + return false; + }); m.def("_initialize", []() { NoGilSection gil; Initialize();