mirror of https://github.com/llvm/torch-mlir
Clang format refresh (#2812)
After noticing a number of commits with unrelated formatting changes, I think the clang-format configuration changed at some point, and unrelated formatting churn is now showing up in otherwise focused commits. Doing a full formatting refresh should help avoid this. The changes made here came from:

```
find lib -iname *.h -o -iname *.cpp | xargs clang-format -i --style=llvm
find include -iname *.h -o -iname *.cpp | xargs clang-format -i --style=llvm
find projects -iname *.h -o -iname *.cpp | xargs clang-format -i --style=llvm
```

pull/2823/head
parent d3fd754b93
commit 494089d53d
```diff
@@ -22,4 +22,4 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Torch, torch);
 }
 #endif

 #endif // TORCHMLIR_C_DIALECTS_H
@@ -10,9 +10,9 @@
 #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
 #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_

-#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LLVM.h"

@@ -78,8 +78,8 @@ struct OpBinder {
 return failure();
 return success();
 }

-ParseResult tensorOperandsList( llvm::SmallVectorImpl<Value> &values) {
+ParseResult tensorOperandsList(llvm::SmallVectorImpl<Value> &values) {
 for (uint32_t i = 0; i < op->getNumOperands(); i++) {
 values.push_back(op->getOperand(i));
 }
@@ -97,7 +97,8 @@ struct OpBinder {
 return success();
 }

-ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) {
+ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
+int64_t idx) {
 if (idx >= op->getNumResults())
 return failure();
 auto t = toValidTensorType(op->getResult(idx).getType());
```
```diff
@@ -37,33 +37,31 @@ TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
 return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
 }

-// This specialization is for Div op. Unlike other binary ops, it doesn't support
-// floating type.
+// This specialization is for Div op. Unlike other binary ops, it doesn't
+// support floating type.
 template <>
 tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
 Operation *op, TensorType outType,
 Value lhs, Value rhs);

 std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
 Operation *op,
 Value params_value,
 Value index_value,
 int32_t axis);

 // Lowers torch.aten.Gather operators to a sequence of TOSA ops.
 // Revised from
 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
-std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
-Operation *op, Type out_type,
-Value params_value,
-Value indices_value);
+std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
+Type out_type, Value params_value,
+Value indices_value);

 std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
 Operation *op, Type outType,
 Value paramsValue, Value indicesValue,
 Value fillValues);


 // Lowers ReduceAll to a sequence of TOSA ops.
 std::optional<Value>
 convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
@@ -67,7 +67,7 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType);
 // op. This allows shape inference during the framework to TOSA lowering.
 template <typename TosaOp, typename... Args>
 TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,
-Args &&... args) {
+Args &&...args) {
 auto op = rewriter.create<TosaOp>(loc, result_ty, args...);

 InferShapedTypeOpInterface shapeInterface =
```
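The `CreateOpAndInfer` helper in the hunk above is a variadic template that perfect-forwards its argument list into the op builder and then runs a shape-inference step; only the spacing of the parameter pack changed here. As a hedged, standalone illustration of the same forwarding pattern (generic C++ with made-up names such as `Widget`, not the MLIR API itself):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>

// Illustrative stand-in for a "create, then post-process" factory: forward an
// arbitrary argument list to a constructor, then run a fixup on the result.
struct Widget {
  std::string name;
  int size;
  Widget(std::string n, int s) : name(std::move(n)), size(s) {}
};

template <typename T, typename... Args>
std::unique_ptr<T> createAndInfer(Args &&...args) {
  auto obj = std::make_unique<T>(std::forward<Args>(args)...);
  obj->size *= 2; // stand-in for the "infer/adjust the result" step
  return obj;
}

int main() {
  auto w = createAndInfer<Widget>("conv", 3);
  std::cout << w->name << " " << w->size << "\n"; // prints "conv 6"
}
```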
```diff
@@ -111,7 +111,7 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty,

 template <typename TosaOp, typename... Args>
 void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
-Type result_ty, Args &&... args) {
+Type result_ty, Args &&...args) {
 auto result =
 CreateOpAndInfer<TosaOp>(rewriter, op->getLoc(), result_ty, args...);
 rewriter.replaceOp(op, result->getResults());
@@ -119,7 +119,7 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,

 // Get accumulator type for AvgPool2dOp.
 LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
 TypeAttr &accType);

 } // namespace tosa
 } // namespace mlir

@@ -36,8 +36,7 @@ class HasValueSemantics
 // This is a weaker form of HasValueSemantics, since that trait also requires no
 // aliasing. That is, HasValueSemantics implies this trait.
 template <typename ConcreteType>
-class ReadOnly
-: public ::mlir::OpTrait::TraitBase<ConcreteType, ReadOnly> {};
+class ReadOnly : public ::mlir::OpTrait::TraitBase<ConcreteType, ReadOnly> {};

 // If a Torch op has this trait, it means that the op is a "trailing underscore"
 // op variant that performs an in-place operation on its first argument. These
```
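The trait classes touched here (`ReadOnly`, and `AllowedInModuleInitializer` in the next hunk) are empty marker classes deriving from `OpTrait::TraitBase` via CRTP; ops advertise a property simply by inheriting from the trait. A minimal, hedged sketch of that marker-trait pattern using hypothetical names rather than the real MLIR types:

```cpp
#include <iostream>
#include <type_traits>

// Minimal CRTP marker-trait pattern: the base template knows the concrete op
// type, and an op "has" a trait simply by listing it as a base class.
template <typename ConcreteOp> struct TraitBase {};

template <typename ConcreteOp>
struct ReadOnlyTrait : TraitBase<ConcreteOp> {};

// A hypothetical op that opts into the trait.
struct MyOp : ReadOnlyTrait<MyOp> {};

template <typename OpT>
constexpr bool isReadOnly = std::is_base_of_v<ReadOnlyTrait<OpT>, OpT>;

int main() {
  std::cout << std::boolalpha << isReadOnly<MyOp> << "\n"; // prints true
}
```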
```diff
@@ -62,7 +61,8 @@ class AllowsTypeRefinement
 // by the IValue importer.
 template <typename ConcreteType>
 class AllowedInModuleInitializer
-: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowedInModuleInitializer> {};
+: public ::mlir::OpTrait::TraitBase<ConcreteType,
+AllowedInModuleInitializer> {};

 } // namespace OpTrait
 } // namespace Torch
@@ -61,7 +61,8 @@ struct TorchLoweringPipelineOptions

 Option<std::string> extraLibrary{
 *this, "extra-library",
-llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")};
+llvm::cl::desc("Filename of MLIR module for splicing into the abstract "
+"interpretation library.")};
 };

 /// Creates a pipeline that lowers the object graph IR that is produced by
@@ -125,8 +126,7 @@ createSimplifyDtypeCalculationsPass();
 std::unique_ptr<OperationPass<func::FuncOp>>
 createDropAbstractInterpCalculationsPass();

-std::unique_ptr<OperationPass<ModuleOp>>
-createEraseModuleInitializerPass();
+std::unique_ptr<OperationPass<ModuleOp>> createEraseModuleInitializerPass();

 std::unique_ptr<OperationPass<ModuleOp>>
 createLowerToBackendContractPass(int maxIterations, bool decompose,
@@ -140,12 +140,7 @@ enum Reduction { None, Mean, Sum, END };
 // Source:
 // https://github.com/pytorch/pytorch/blob/master/c10/core/MemoryFormat.h
 //===----------------------------------------------------------------------===//
-enum MemoryFormat {
-Contiguous,
-Preserve,
-ChannelsLast,
-ChannelsLast3d
-};
+enum MemoryFormat { Contiguous, Preserve, ChannelsLast, ChannelsLast3d };

 //===----------------------------------------------------------------------===//
 // Possible values for `layout` argument in PyTorch ops that support it.
```
```diff
@@ -121,8 +121,7 @@ LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter,
 // Helper to create a tensor filled with the given scalar. Scalar would be
 // converted the to the element type of the given tensor type.
 Value createInitTensor(PatternRewriter &rewriter, Location loc,
-BaseTensorType resultType, Value scalar,
-Value sizeList);
+BaseTensorType resultType, Value scalar, Value sizeList);

 // Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
 // would be converted to the element type of the given `inputType`.
@@ -9,7 +9,8 @@

 #include "torch-mlir-c/Dialects.h"

-#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "mlir/CAPI/Registration.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"

-MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch, mlir::torch::Torch::TorchDialect)
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch,
+mlir::torch::Torch::TorchDialect)
@@ -30,6 +30,4 @@ namespace {
 #include "torch-mlir/Conversion/Passes.h.inc"
 } // end namespace

-void mlir::torch::registerConversionPasses() {
-::registerPasses();
-}
+void mlir::torch::registerConversionPasses() { ::registerPasses(); }
```
```diff
@@ -82,7 +82,8 @@ public:
 // temp = multiplier * currentSeed + incrementStep
 Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
 Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
-globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
+globalVar =
+rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
 rewriter.create<ml_program::GlobalStoreOp>(
 loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
 globalVar);
```
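The hunk above only re-wraps the seed update, but the comment in it spells out the recurrence being materialized as `arith.muli`/`arith.addi`: `temp = multiplier * currentSeed + incrementStep`, i.e. a linear-congruential-style update. A standalone sketch of that arithmetic (the multiplier/increment constants below are illustrative only, not taken from torch-mlir):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  uint64_t seed = 42;                                      // current global seed
  const uint64_t multiplier = 6364136223846793005ULL;      // example LCG multiplier
  const uint64_t incrementStep = 1442695040888963407ULL;   // example increment
  for (int i = 0; i < 3; ++i) {
    // Same shape as the lowered ops: new seed = multiplier * seed + increment,
    // wrapping modulo 2^64 like fixed-width integer mul/add.
    seed = multiplier * seed + incrementStep;
    std::cout << seed << "\n";
  }
}
```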
```diff
@@ -29,7 +29,8 @@ using namespace mlir::torch::onnx_c;
 // thing here, so we simplify.
 void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 OnnxCustomOpConversionPattern &patterns) {
-patterns.onOp("HardSigmoid", 6,
+patterns.onOp(
+"HardSigmoid", 6,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
 Value tensorOperand;
@@ -39,8 +40,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 binder.f32FloatAttr(beta, "beta", 0.5f) ||
 binder.tensorResultType(resultType))
 return failure();

-// HardSigmoid computes the following expression: max(0, min(1, alpha * x + beta))
+// HardSigmoid computes the following expression:
+// max(0, min(1, alpha * x + beta))
 Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
 binder.getLoc(), rewriter.getType<Torch::FloatType>(),
 rewriter.getF64FloatAttr(alpha));
```
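The comment reflowed in the hunk above gives the HardSigmoid expression the pattern lowers: `max(0, min(1, alpha * x + beta))`. A scalar sketch of that formula (beta's default of 0.5 is visible in the hunk; alpha's default of 0.2 is the usual ONNX default and is an assumption here):

```cpp
#include <algorithm>
#include <iostream>

// Scalar HardSigmoid: max(0, min(1, alpha * x + beta)).
float hardSigmoid(float x, float alpha = 0.2f, float beta = 0.5f) {
  return std::max(0.0f, std::min(1.0f, alpha * x + beta));
}

int main() {
  for (float x : {-5.0f, -1.0f, 0.0f, 1.0f, 5.0f})
    std::cout << "hardSigmoid(" << x << ") = " << hardSigmoid(x) << "\n";
}
```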
```diff
@@ -51,7 +53,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(

 // Expression: alpha * x + beta
 Value alpha_x_plus_beta = rewriter.create<Torch::AtenAddScalarOp>(
-binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha);
+binder.getLoc(), resultType, tensorOperand, constBeta,
+/*alpha=*/constAlpha);

 // Expression: min(1, alpha * x + beta)
 Value constantOne = rewriter.create<Torch::ConstantIntOp>(
@@ -100,7 +103,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 rewriter.replaceOpWithNewOp<Torch::AtenLtTensorOp>(
 binder.op, resultType, lhs, rhs);
 return success();
 });
 patterns.onOp("LessOrEqual", 1,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
@@ -109,9 +112,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 binder.tensorResultType(resultType)) {
 return failure();
 }
 rewriter.replaceOpWithNewOp<Torch::AtenLeTensorOp>(
 binder.op, resultType, lhs, rhs);
 return success();
 });
 patterns.onOp("Log", 1,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
@@ -126,7 +129,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 return success();
 });
 patterns.onOp("MatMul", 13,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
 Value lhs, rhs;
 if (binder.tensorOperands(lhs, rhs) ||
@@ -206,20 +209,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 return success();
 });
 patterns.onOp("Mul", 7,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
 Value lhs, rhs;
 if (binder.tensorOperands(lhs, rhs) ||
 binder.tensorResultType(resultType)) {
 return failure();
 }
 rewriter.replaceOpWithNewOp<Torch::AtenMulTensorOp>(
 binder.op, resultType, lhs, rhs);
 return success();
 });
 patterns.onOp("NonZero", 13,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
 Value operand;
 if (binder.tensorOperand(operand) ||
 binder.tensorResultType(resultType)) {
```
```diff
@@ -332,41 +335,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 binder.op, resultType, lhs, rhs);
 return success();
 });
-patterns.onOp("Max", 1,
-[](OpBinder binder, ConversionPatternRewriter &rewriter) {
+patterns.onOp(
+"Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
 llvm::SmallVector<Value> operands;
 if (binder.tensorOperandsList(operands) ||
-binder.tensorResultType(resultType) ||
-operands.size() == 0) {
-return failure();
-}
-Value result = operands[0];
-for (uint64_t i = 1; i < operands.size(); i++) {
-result = rewriter.create<Torch::AtenMaximumOp>(
-binder.getLoc(), resultType, result, operands[i]);
-}
-rewriter.replaceOp(binder.op, result.getDefiningOp());
-return success();
-});
-patterns.onOp("Min", 1,
-[](OpBinder binder, ConversionPatternRewriter &rewriter) {
-Torch::ValueTensorType resultType;
-llvm::SmallVector<Value> operands;
-if (binder.tensorOperandsList(operands) ||
-binder.tensorResultType(resultType) ||
-operands.size() == 0) {
-return failure();
-}
-Value result = operands[0];
-for (uint64_t i = 1; i < operands.size(); i++) {
-result = rewriter.create<Torch::AtenMinimumOp>(
-binder.getLoc(), resultType, result, operands[i]);
-}
-rewriter.replaceOp(
-binder.op, result.getDefiningOp());
-return success();
-});
+binder.tensorResultType(resultType) || operands.size() == 0) {
+return failure();
+}
+Value result = operands[0];
+for (uint64_t i = 1; i < operands.size(); i++) {
+result = rewriter.create<Torch::AtenMaximumOp>(
+binder.getLoc(), resultType, result, operands[i]);
+}
+rewriter.replaceOp(binder.op, result.getDefiningOp());
+return success();
+});
+patterns.onOp(
+"Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+Torch::ValueTensorType resultType;
+llvm::SmallVector<Value> operands;
+if (binder.tensorOperandsList(operands) ||
+binder.tensorResultType(resultType) || operands.size() == 0) {
+return failure();
+}
+Value result = operands[0];
+for (uint64_t i = 1; i < operands.size(); i++) {
+result = rewriter.create<Torch::AtenMinimumOp>(
+binder.getLoc(), resultType, result, operands[i]);
+}
+rewriter.replaceOp(binder.op, result.getDefiningOp());
+return success();
+});
 patterns.onOp("Neg", 1,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
 Torch::ValueTensorType resultType;
```
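The `Max`/`Min` patterns reformatted above fold an arbitrary-length operand list by chaining pairwise maximum/minimum ops: start from `operands[0]` and combine each remaining operand into the running result. A standalone sketch of the same left-fold on plain floats (illustrative only, no MLIR types):

```cpp
#include <algorithm>
#include <iostream>
#include <vector>

// Left-fold a non-empty operand list with a pairwise max, mirroring the loop
// in the "Max" pattern: result = max(result, operands[i]) for i >= 1.
float foldMax(const std::vector<float> &operands) {
  float result = operands[0];
  for (size_t i = 1; i < operands.size(); ++i)
    result = std::max(result, operands[i]);
  return result;
}

int main() {
  std::vector<float> operands{3.0f, -1.5f, 7.25f, 0.0f};
  std::cout << foldMax(operands) << "\n"; // prints 7.25
}
```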
```diff
@@ -693,7 +693,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 binder.getLoc(),
 Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
 cstStrides);
-Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+Value cstFalse =
+rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
 Value cstCeilMode = cstFalse;
 Value cstCountIncludePad = cstFalse;
 Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
@@ -903,7 +904,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 return failure();
 }
 rewriter.replaceOpWithNewOp<Torch::AtenPowTensorTensorOp>(
 binder.op, resultType, lhs, rhs);
 return success();
 });
 patterns.onOp(
```
```diff
@@ -42,56 +42,63 @@ Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,

 void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 OnnxCustomOpConversionPattern &patterns) {
-patterns.onOp("QuantizeLinear", 1,
-[](OpBinder binder, ConversionPatternRewriter &rewriter) {
-Torch::ValueTensorType resultType;
-llvm::SmallVector<Value> operands;
-if (binder.tensorOperands(operands, 3) ||
-binder.tensorResultType(resultType))
-return failure();
+patterns.onOp(
+"QuantizeLinear", 1,
+[](OpBinder binder, ConversionPatternRewriter &rewriter) {
+Torch::ValueTensorType resultType;
+llvm::SmallVector<Value> operands;
+if (binder.tensorOperands(operands, 3) ||
+binder.tensorResultType(resultType))
+return failure();

 Value operand = operands[0];
 Value scale = operands[1];
 Value zeropoint = operands[2];

 auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
 if (!scaleTy || !scaleTy.hasSizes())
-return rewriter.notifyMatchFailure(binder.op,
-"requires known rank");
-if (!resultType.hasDtype())
-return rewriter.notifyMatchFailure(
-binder.op, "requires known result dtype");
+return rewriter.notifyMatchFailure(binder.op, "requires known rank");
+if (!resultType.hasDtype())
+return rewriter.notifyMatchFailure(binder.op,
+"requires known result dtype");

 if (scaleTy.getSizes().size() == 0) {
 Type qTy = resultType.getDtype();

 if (qTy.isUnsignedInteger(8)) {
 qTy = rewriter.getType<Torch::QUInt8Type>();
 } else if (qTy.isSignedInteger(8)) {
 qTy = rewriter.getType<Torch::QInt8Type>();
 } else if (qTy.isSignedInteger(32)) {
 qTy = rewriter.getType<Torch::QInt32Type>();
 } else {
-return rewriter.notifyMatchFailure(binder.op, "unsupported result dtype");
-}
+return rewriter.notifyMatchFailure(binder.op,
+"unsupported result dtype");
+}

-auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(resultType.getOptionalSizes(), qTy);
-auto torchqTy = Torch::getScalarTypeForType(qTy);
+auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
+resultType.getOptionalSizes(), qTy);
+auto torchqTy = Torch::getScalarTypeForType(qTy);

 Value tyConst = rewriter.create<Torch::ConstantIntOp>(
 binder.getLoc(), rewriter.getType<Torch::IntType>(),
-rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast<int64_t>(torchqTy)));
+rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+static_cast<int64_t>(torchqTy)));

-scale = rewriter.create<Torch::AtenItemOp>(binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
-zeropoint = rewriter.create<Torch::AtenItemOp>(binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
+scale = rewriter.create<Torch::AtenItemOp>(
+binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
+zeropoint = rewriter.create<Torch::AtenItemOp>(
+binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);

-auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst);
-rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType, quantize);
-return success();
-}
+auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
+binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst);
+rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(
+binder.op, resultType, quantize);
+return success();
+}

 return failure();
 });
 patterns.onOp(
 "QLinearMatMul", 1,
 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
```
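The `QuantizeLinear` pattern above extracts a scalar scale and zero point, builds an `AtenQuantizePerTensorOp` with the chosen quantized dtype, and takes its integer representation. Per the usual ONNX/PyTorch per-tensor quantization semantics (an assumption here, since the hunk itself only shows the op wiring), the scalar arithmetic is `q = clamp(round(x / scale) + zero_point, qmin, qmax)`. A standalone sketch for a uint8 result, with made-up values:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Scalar per-tensor quantization to uint8: round, shift by the zero point,
// then clamp into the representable range [0, 255].
uint8_t quantizePerTensor(float x, float scale, int zeroPoint) {
  int q = static_cast<int>(std::nearbyint(x / scale)) + zeroPoint;
  return static_cast<uint8_t>(std::clamp(q, 0, 255));
}

int main() {
  float scale = 0.05f;   // illustrative scale
  int zeroPoint = 128;   // illustrative zero point
  for (float x : {-1.0f, 0.0f, 0.5f, 10.0f})
    std::cout << x << " -> " << +quantizePerTensor(x, scale, zeroPoint) << "\n";
}
```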
```diff
@@ -1245,7 +1252,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
 }

 // Convert dynamic shape dimension.
-for (unsigned i = 0; i < shape.size(); i++){
+for (unsigned i = 0; i < shape.size(); i++) {
 if (shape[i] == ShapedType::kDynamic)
 shape[i] = Torch::kUnknownSize;
 }
@@ -43,7 +43,8 @@ public:
 LogicalResult
 matchAndRewrite(AtenDimOp op, OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter) const override {
-auto rank = rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf());
+auto rank =
+rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf());
 rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
 op, getTypeConverter()->convertType(op.getType()), rank);
 return success();
@@ -74,7 +75,8 @@ public:
 matchAndRewrite(AtenOp op,
 typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter) const override {
-rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(), adaptor.getB());
+rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(),
+adaptor.getB());
 return success();
 }
 };
@@ -112,10 +114,10 @@ public:
 typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
 ConversionPatternRewriter &rewriter) const override {
 Location loc = op.getLoc();
-Value a =
-convertScalarToDtype(rewriter, loc, adaptor.getA(), rewriter.getF64Type());
-Value b =
-convertScalarToDtype(rewriter, loc, adaptor.getB(), rewriter.getF64Type());
+Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(),
+rewriter.getF64Type());
+Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(),
+rewriter.getF64Type());
 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b);
 return success();
 }
```
```diff
@@ -176,15 +178,16 @@ public:
 unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
 Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
 auto shapedType =
 RankedTensorType::get(type.getShape(), builtinTensorElemTy);
 auto rawData = elements.getRawData();
-DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
-shapedType, rawData);
+DenseElementsAttr newAttr =
+DenseElementsAttr::getFromRawBuffer(shapedType, rawData);
 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
 return success();
 }
 }
-if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
+if (auto elements =
+op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
 if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
 if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
 Type builtinTensorElemTy =
@@ -360,7 +363,8 @@ public:
 // -----------------------------------------------------------------------------

 namespace {
-class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith> {
+class ConvertTorchToArith
+: public ConvertTorchToArithBase<ConvertTorchToArith> {
 public:
 void getDependentDialects(DialectRegistry &registry) const override {
 registry.insert<func::FuncDialect>();
@@ -110,22 +110,32 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
 // Example:
 // input = tensor([[[0., 1., 2., 3.],
 //                  [4., 5., 6., 7.]]])
-// torch.ops.aten.reflection_pad1d(input, (3,1)) ; padding_left = 3, padding_right = 1
-// tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
-// [7., 6., 5., 4., 5., 6., 7., 6.]]])
-// Checks: 1) Each of padding_left and padding_right must be non-negative less than size of last dimension
-// Implementation: a) Construct a result tensor of shape of input tensor except for the last dimension.
-// The last dimension of the result tensor should be last dimension of input tensor +
-// left padding size + right padding size. INitialize result tensor to all zeros
-// b) Setup affine map to take slice from input tensor of size left padding starting from
-// second column onwards as first column is reflection boundary
+// torch.ops.aten.reflection_pad1d(input, (3,1));
+// padding_left = 3,
+// padding_right = 1
+// output = tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
+// [7., 6., 5., 4., 5., 6., 7., 6.]]])
+// Checks: 1) Each of padding_left and padding_right must be non-negative and
+// less than the size of the last dimension.
+// Implementation: a) Construct a result tensor of
+// shape of input tensor except for the last dimension.
+// The last dimension of the result tensor should be last
+// dimension of input tensor + left padding size + right
+// padding size. Initialize result tensor to all zeros
+// b) Setup affine map to take slice from input tensor of size
+// left padding starting from
+// second column onwards as first column is reflection
+// boundary
 // c) Reflect the affine map to have resultant slice reflected
 // d) Take the slice and write from begining in result tensor
 // e) write the original tensor next into result tensor
-// f) Setup affine map to take slice from input tensor of right padding size ending
-// at second last column as last column is reflection boundary for right padding
+// f) Setup affine map to take slice from input tensor of right
+// padding size ending
+// at second last column as last column is reflection
+// boundary for right padding
 // g) Reflect the affine map to have resultant slice reflected
-// h) Take the slice and write from left padding size + orignal tensor last dim size
+// h) Take the slice and write from left padding size + orignal
+// tensor last dim size
 // into result tensor
 // Uses the ideas/code used for AtenReflectionPad2dOp
 namespace {
```
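The comment block reflowed above describes `reflection_pad1d`: the first and last elements of the last dimension act as reflection boundaries and are not repeated, and the padded width is the original width plus the left and right padding. A standalone 1-D sketch that reproduces the example from the comment (input row `0 1 2 3`, padding `(3, 1)`):

```cpp
#include <iostream>
#include <vector>

// 1-D reflection padding; padLeft/padRight must each be smaller than
// input.size(), matching the check described in the comment above.
std::vector<float> reflectionPad1d(const std::vector<float> &input,
                                   int padLeft, int padRight) {
  std::vector<float> out;
  for (int i = padLeft; i >= 1; --i)               // reflected left tile: 3,2,1
    out.push_back(input[i]);
  out.insert(out.end(), input.begin(), input.end()); // original values
  int n = static_cast<int>(input.size());
  for (int i = 1; i <= padRight; ++i)              // reflected right tile
    out.push_back(input[n - 1 - i]);
  return out;
}

int main() {
  std::vector<float> row{0, 1, 2, 3};
  for (float v : reflectionPad1d(row, /*padLeft=*/3, /*padRight=*/1))
    std::cout << v << " "; // prints: 3 2 1 0 1 2 3 2
  std::cout << "\n";
}
```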
```diff
@@ -138,7 +148,7 @@ public:
 ConversionPatternRewriter &rewriter) const override {
 if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
 return failure();

 SmallVector<int64_t> padInts;
 if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
 return rewriter.notifyMatchFailure(
@@ -158,55 +168,68 @@ public:
 return rewriter.create<arith::SubIOp>(loc, x, y);
 };

-enum PadLocation {PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER=2};
+enum PadLocation { PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER = 2 };

 Value input = adaptor.getSelf();
 Type indexType = rewriter.getIndexType();
 Value zero = getConstant(rewriter, loc, 0, indexType);
 Value one = getConstant(rewriter, loc, 1, indexType);
 auto inputType = llvm::cast<RankedTensorType>(input.getType());
-auto outputType = llvm::cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
+auto outputType = llvm::cast<RankedTensorType>(
+getTypeConverter()->convertType(op->getResult(0).getType()));
 unsigned numDims = inputType.getRank();
 assert(numDims >= 2 && "Not enough input dimensions");
 int64_t lastDim = numDims - 1;
 SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
-Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, inputShape[2] will give 4
+Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2,
+// inputShape[2] will give 4

 Value tileWidth[3], extractOffset[3], insertOffset[3];

-tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType);
-tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType);
+tileWidth[PAD_LEFT] =
+getConstant(rewriter, loc, padInts[PAD_LEFT], indexType);
+tileWidth[PAD_RIGHT] =
+getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType);
 tileWidth[PAD_CENTER] = lastDimSize;

 extractOffset[PAD_LEFT] = one;
-// for (1,2,4) input, padding (3,1) lastDimSize=4, 4 - 1 - 1 = 2 [3,5, 6,7], so start offset to 6, which is right
-// lasDimSize - (tileWidth[PAD_RIGHT] + one)
-extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one));
+// The offset for the right hand padding "bar" is:
+// [right] lastDimSize - (tileWidth[PAD_RIGHT] + one)
+extractOffset[PAD_RIGHT] =
+createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one));
 extractOffset[PAD_CENTER] = zero;

 insertOffset[PAD_LEFT] = zero;
 insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]);
 insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT];


 SmallVector<Value> resultShape{inputShape};
-// Result's last dimension will have shape lastDimSize + left padding size + right padding size
-resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT]));
-Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType());
+// Result's last dimension will have size:
+// lastDimSize + left padding size + right padding size
+resultShape[lastDim] =
+createIAdd(resultShape[lastDim],
+createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT]));
+Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
+inputType.getElementType());

-// Helper to reflect/reverse the i-th dimension of an affine map without symbols. This only works if applied on a tensor
-// for which the corresponding dimension has a statically known size
-auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) {
+// Helper to reflect/reverse the i-th dimension of an affine map without
+// symbols. This only works if applied on a tensor for which the
+// corresponding dimension has a statically known size
+auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
+int64_t size) {
 AffineExpr d = map.getResult(i);
-return map.replace(d, size - d - 1, numDims, 0); // left reflect for (3,1) on input shape (1,2,4). size = 3, lastDim=2, numDims=3
+return map.replace(d, size - d - 1, numDims,
+0); // left reflect for (3,1) on input shape (1,2,4).
+    // size = 3, lastDim=2, numDims=3
 };

-SmallVector<utils::IteratorType> iteratorTypes{numDims, utils::IteratorType::parallel};
+SmallVector<utils::IteratorType> iteratorTypes{
+numDims, utils::IteratorType::parallel};
 auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
 SmallVector<Value> allOneStrides(numDims, one);

 auto addTileToResult = [&](PadLocation padPosition) {
 // Create the tile by extracting a slice from the input tensor.
 SmallVector<Value> extractShape{inputShape};
 extractShape[lastDim] = tileWidth[padPosition];
 SmallVector<Value> extractOffsets(numDims, zero);
@@ -214,35 +237,39 @@ public:
 Value tile = rewriter.create<tensor::ExtractSliceOp>(
 loc, input, extractOffsets, extractShape, allOneStrides);


 auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
-// Setup the affine map function to resverse the tile along the horizontal for left and right slices
-if(padPosition < PAD_CENTER) {
-inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]);
-// Take reflected slice as per inputMap
-tile = rewriter.create<linalg::GenericOp>(loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
-tile, ArrayRef({inputMap, idMap}), iteratorTypes,
-[](OpBuilder &b, Location nestedLoc, ValueRange args) {
-b.create<linalg::YieldOp>(nestedLoc, args[0]);
-}).getResult(0);
+// Setup the affine map function to resverse the tile along the horizontal
+// for left and right slices
+if (padPosition < PAD_CENTER) {
+inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]);
+// Take reflected slice as per inputMap
+tile = rewriter
+.create<linalg::GenericOp>(
+loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
+tile, ArrayRef({inputMap, idMap}), iteratorTypes,
+[](OpBuilder &b, Location nestedLoc, ValueRange args) {
+b.create<linalg::YieldOp>(nestedLoc, args[0]);
+})
+.getResult(0);
 }
 // Insert the tile in the resultTensor
 SmallVector<Value> insertOffsets(numDims, zero);
 insertOffsets[lastDim] = insertOffset[padPosition];
-resultTensor = rewriter.create<tensor::InsertSliceOp>(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
+resultTensor = rewriter.create<tensor::InsertSliceOp>(
+loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
 };

-if(padInts[PAD_LEFT] > 0)
+if (padInts[PAD_LEFT] > 0)
 addTileToResult(PAD_LEFT);
-if(padInts[PAD_RIGHT] > 0)
+if (padInts[PAD_RIGHT] > 0)
 addTileToResult(PAD_RIGHT);
 addTileToResult(PAD_CENTER);

 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputType, resultTensor);
 return success();
 }
 };
-}
+} // namespace

 namespace {

```
```diff
@@ -79,7 +79,8 @@ public:
 int64_t dim;
 if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
 return op.emitError("unimplemented: dim is not constant");
-int64_t inputRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
+int64_t inputRank =
+adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
 dim = toPositiveDim(dim, inputRank);
 if (!isValidDim(dim, inputRank))
 return rewriter.notifyMatchFailure(op, "dim is statically invalid");
@@ -248,9 +249,9 @@ public:
 }

 if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) {
-return rewriter.notifyMatchFailure(
-op,
-"Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag.");
+return rewriter.notifyMatchFailure(op,
+"Unimplemented: Mean and Max mode are "
+"not supported yet for EmbeddingBag.");
 }

 bool isSparse;
@@ -291,28 +292,28 @@ public:
 SmallVector<AffineExpr> indicesExpr;
 indicesExpr.push_back(mlir::getAffineDimExpr(1, context));
 auto indicesIndexingMap =
 AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
 indicesExpr, context);

 SmallVector<AffineExpr> offsetsExpr;
 offsetsExpr.push_back(mlir::getAffineDimExpr(0, context));

 auto offsetIndexingMap =
 AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
 offsetsExpr, context);

 SmallVector<AffineExpr> outputExpr;
 outputExpr.push_back(mlir::getAffineDimExpr(0, context));
 outputExpr.push_back(mlir::getAffineDimExpr(2, context));

 auto outputIndexingMap =
 AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
 outputExpr, context);

 SmallVector<AffineMap, 3> indexingMaps = {
 indicesIndexingMap,
 offsetIndexingMap,
 outputIndexingMap,
 };

 // Reduce along the indices dim
```
```diff
@@ -326,15 +327,15 @@ public:
 Value indicesLength;
 if (!discardLastOffset) {
 SmallVector<Value> sizes{getDimOp(rewriter, loc, offsets, 0),
 embeddingDim};

 initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy);
 offsetsLength = getDimOp(rewriter, loc, offsets, 0);
 indicesLength = getDimOp(rewriter, loc, indices, 0);
 } else {
 return rewriter.notifyMatchFailure(
 op, "Unimplemented: include last offset is not yet "
 "supported for EmbeddingBag.");
 }

 Value embeddingBagResult =
@@ -351,10 +352,10 @@ public:

 Value indexI = b.create<linalg::IndexOp>(loc, /*value=*/0);
 Value indexIToInt = castIndexToInt64(b, loc, indexI);
-Value one = getConstant(
-b, loc, 1,
-mlir::IntegerType::get(getContext(), 64,
-IntegerType::Signless));
+Value one =
+getConstant(b, loc, 1,
+mlir::IntegerType::get(
+getContext(), 64, IntegerType::Signless));
 Value offsetIndexPlusOneInt =
 b.create<arith::AddIOp>(loc, indexIToInt, one);

@@ -378,7 +379,7 @@ public:
 loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex);
 Value offsetLessThanOrEqualToIndicesIndex =
 b.create<arith::OrIOp>(loc, offsetLessThanIndicesIndex,
 offsetEqualToIndicesIndex);

 Value indicesIndexLessThanNextOffset =
 b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
@@ -393,19 +394,18 @@ public:
 castIntToIndex(b, loc, indexInIndices));
 indexIntoWeight.push_back(
 b.create<linalg::IndexOp>(loc, /*value=*/2));
-Value weightElem = b.create<tensor::ExtractOp>(
-loc, weight, indexIntoWeight);
+Value weightElem =
+b.create<tensor::ExtractOp>(loc, weight, indexIntoWeight);

-Value addResult = b.create<arith::AddFOp>(loc, weightElem,
-initTensorElem);
-Value select =
-b.create<arith::SelectOp>(loc, indicesIndexWithinBounds,
-addResult, initTensorElem);
+Value addResult =
+b.create<arith::AddFOp>(loc, weightElem, initTensorElem);
+Value select = b.create<arith::SelectOp>(
+loc, indicesIndexWithinBounds, addResult, initTensorElem);
 b.create<linalg::YieldOp>(loc, select);
 })
 .getResult(0);

 // cast outputType.
 auto restulType0 = typeConverter->convertType(op->getResult(0).getType());
 Value castedEmbeddingBagResult =
 rewriter.create<tensor::CastOp>(loc, restulType0, embeddingBagResult);
```
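The EmbeddingBag hunks above handle only `MODE_SUM`: `offsets` marks where each bag starts in `indices`, and the generic op accumulates the selected `weight` rows whose index falls inside the current bag. A standalone sketch of that sum-mode semantics on plain vectors (data below is made up for illustration; the lowering itself uses affine maps and a select to stay within bag bounds):

```cpp
#include <iostream>
#include <vector>

// embedding_bag, "sum" mode with offsets: bag i covers
// indices[offsets[i] .. offsets[i+1]) (the last bag runs to the end), and its
// output is the element-wise sum of the selected weight rows.
std::vector<std::vector<float>>
embeddingBagSum(const std::vector<std::vector<float>> &weight,
                const std::vector<int> &indices,
                const std::vector<int> &offsets) {
  size_t numBags = offsets.size();
  size_t dim = weight[0].size();
  std::vector<std::vector<float>> out(numBags, std::vector<float>(dim, 0.0f));
  for (size_t bag = 0; bag < numBags; ++bag) {
    size_t end = (bag + 1 < numBags) ? static_cast<size_t>(offsets[bag + 1])
                                     : indices.size();
    for (size_t i = static_cast<size_t>(offsets[bag]); i < end; ++i)
      for (size_t d = 0; d < dim; ++d)
        out[bag][d] += weight[indices[i]][d];
  }
  return out;
}

int main() {
  std::vector<std::vector<float>> weight{{1, 1}, {2, 2}, {3, 3}};
  std::vector<int> indices{0, 2, 1}; // bag 0 -> rows 0,2 ; bag 1 -> row 1
  std::vector<int> offsets{0, 2};
  for (const auto &row : embeddingBagSum(weight, indices, offsets))
    std::cout << row[0] << " " << row[1] << "\n"; // prints "4 4" then "2 2"
}
```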
```diff
@@ -439,7 +439,7 @@ public:
 rewriter.create<tensor::CastOp>(loc, resultType3, indicesOut);

 rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult,
 castedBagSizeResult, castedMaxIndices});

 return success();
 }
@@ -552,7 +552,8 @@ static Value makeIndexValuePositive(OpBuilder &b, Location loc, Value index,
 // e.g. x: [2, 3]
 // x[[4], [6, 1]] -> x[6, 4]
 namespace {
-class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
+class ConvertAtenIndexTensorHackedTwinOp
+: public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
 public:
 using OpConversionPattern::OpConversionPattern;
 LogicalResult
@@ -165,7 +165,8 @@ public:
 Location loc = op->getLoc();
 MLIRContext *context = op.getContext();
 Value self = adaptor.getSelf();
-auto selfRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
+auto selfRank =
+adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
 Type elementType =
 adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
 Value c1 =
@@ -535,7 +536,8 @@ public:
 RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
 RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
 Type newResultType = getTypeConverter()->convertType(op.getType());
-Type resultElementType = newResultType.cast<RankedTensorType>().getElementType();
+Type resultElementType =
+newResultType.cast<RankedTensorType>().getElementType();
 Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType();
 Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType();

@@ -547,13 +549,15 @@ public:
 // Convert the inputs element type equivalent to the result' element type.
 if (lhsElementType != rhsElementType) {
 if (lhsElementType != resultElementType) {
-// True if the lhs element type is not equal to the result' element type.
-lhs = torch_to_linalg::convertTensorToElementType(
-rewriter, loc, lhs, resultElementType);
+// True if the lhs element type is not equal to the result' element
+// type.
+lhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, lhs,
+resultElementType);
 } else {
-// True if the rhs element type is not equal to the result' element type.
-rhs = torch_to_linalg::convertTensorToElementType(
-rewriter, loc, rhs, resultElementType);
+// True if the rhs element type is not equal to the result' element
+// type.
+rhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, rhs,
+resultElementType);
 }
 }

```
@ -571,7 +575,8 @@ public:
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);

Value initTensor0 = createZeroInitTensor(
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType);
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2},
resultElementType);

Value bmm =
rewriter

@ -634,7 +639,8 @@ public:
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
SmallVector<int64_t> dilationInts;
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");

@ -838,8 +844,10 @@ public:

Value conv;
// the code so far is able to respect all numSpacialDims
// the code below this point is numSpacialDims specific and groupSize specific
// the code below this point is numSpacialDims specific and groupSize
// specific
// TODO: factor out the above code into a helper function, and then separate convolution into:
// TODO: factor out the above code into a helper function, and then separate
// convolution into:
// - grouped 1d-3d
// - ungrouped 1d-3d
if (groupSize == 1) {
@ -854,20 +862,20 @@ public:
.getResult(0);
break;
case 2:
conv =
rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, outputTensor.getType(), ValueRange{paddedInput, weight},
outputTensor, stridesAttr, dilationAttr)
conv = rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 3:
conv =
rewriter
.create<linalg::Conv3DNcdhwFcdhwOp>(
loc, outputTensor.getType(), ValueRange{paddedInput, weight},
outputTensor, stridesAttr, dilationAttr)
conv = rewriter
.create<linalg::Conv3DNcdhwFcdhwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:
return rewriter.notifyMatchFailure(

@ -877,7 +885,7 @@ public:
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
} else {
if(numSpacialDims != 2)
if (numSpacialDims != 2)
return rewriter.notifyMatchFailure(
op, "unimplemented: only 2D grouped convolution supported");
@ -901,11 +909,11 @@ public:
loc, collapsedType, weight, collapsedDims);

conv = rewriter
.create<linalg::DepthwiseConv2DNchwChwOp>(
loc, outputTensor.getType(),
ValueRange{paddedInput, collapsedWeight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);

Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);

@ -979,7 +987,7 @@ public:
conv = rewriter.create<tensor::CollapseShapeOp>(
loc, outputTensor.getType(), conv,
expandOutputTensor.getReassociationIndices());
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
@ -194,7 +194,6 @@ public:
};
} // namespace

void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {

@ -100,11 +100,11 @@ public:
if (integerTy.isUnsigned())
return rewriter.notifyMatchFailure(
op, opName + " to linalg.* requires input element type "
"to be signed in case of integer");
} else {
return rewriter.notifyMatchFailure(
op, opName + " to linalg.* requires Float or Integer "
"input element type");
}
}

@ -144,8 +144,7 @@ public:
}

Value filledTensorVal =
rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal)
.result();
rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal).result();

// Create the affine expressions that will be used to
// iterate over the input and output tensors.
@ -186,7 +185,7 @@ public:

Value resultVal, predicate;
if (inElementType.isa<mlir::FloatType>()) {
arith::CmpFPredicate predType;
if (isMax) {
predType = arith::CmpFPredicate::OGT;
resultVal = rewriter.create<arith::MaximumFOp>(

@ -198,7 +197,7 @@ public:
}

predicate = rewriter.create<arith::CmpFOp>(nestedLoc, predType,
newValue, oldValue);
} else {
arith::CmpIPredicate predType;
if (isMax) {

@ -220,8 +219,8 @@ public:
});

// This cast is required to fix the shape in the case of keepDim=True
Value valuesCast = rewriter.create<tensor::CastOp>(
loc, valResultType, linalgOp.getResult(0));
Value valuesCast = rewriter.create<tensor::CastOp>(loc, valResultType,
linalgOp.getResult(0));
Value idxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
linalgOp.getResult(1));
rewriter.replaceOp(op, {valuesCast, idxCast});
@ -345,7 +344,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
auto abs = b.create<math::AbsFOp>(loc, self);
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
Value ord = convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType);
Value ord =
convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenFrobeniusNormDimOp>(op)) {

@ -427,8 +427,8 @@ private:
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();

// `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the dimensions of the
// input tensor.
// `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the
// dimensions of the input tensor.
for (int64_t i = 0; i < inputType.getRank(); i++)
opInfo.dimSet.insert(i);
@ -83,209 +83,224 @@ public:

namespace {

// Lower aten.replication_pad2d operator into a sequence of
// tensor.extract_slice and tensor.concat operations.

class ConvertAtenReplicationPad2dOp
: public OpConversionPattern<AtenReplicationPad2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenReplicationPad2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

Location loc = op->getLoc();
Value input = adaptor.getSelf();
auto inputType = llvm::cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();
unsigned numDims = inputType.getRank();
assert(numDims >= 2 && "Not enough input dimensions");

SmallVector<int64_t> padInts;
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
return rewriter.notifyMatchFailure(
op, "only support constant int pad ranges");
uint64_t padRank = padInts.size() / 2;
if (padRank * 2 != padInts.size())
return rewriter.notifyMatchFailure(op, "pad range size is not even");
if (inputRank < 0 || padRank > (uint64_t)inputRank)
return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank");

SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
int64_t hDim = numDims - 1;
int64_t vDim = numDims - 2;
Value hDimSize = inputShape[hDim];
Value vDimSize = inputShape[vDim];

enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 };
enum tileVLoc { TOP = 0, VCENTER = 2, BOTTOM = 1, };
enum tileVLoc {
TOP = 0,
VCENTER = 2,
BOTTOM = 1,
};
// vTile denotes the vertical size of the tile
// hTile denotes the horizontal size of the tile
// The padding results are composed of following tiles:
// vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT]
// vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], vTile[VCENTER]hTile[RIGHT]
// vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER],
// vTile[VCENTER]hTile[RIGHT] vTile[BOTTOM]hTile[LEFT],
// vTile[BOTTOM]hTile[LEFT], vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT]
// vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT]
// vTile[VCENTER]hTile[HCENTER] is the original input tensor
Type indexType = rewriter.getIndexType();
Value vTile[3];
Value hTile[3];
vTile[VCENTER] = vDimSize;
hTile[HCENTER] = hDimSize;
vTile[TOP] = getConstant(rewriter, loc, padInts[2], indexType);
vTile[BOTTOM] = getConstant(rewriter, loc, padInts[3], indexType);
hTile[LEFT] = getConstant(rewriter, loc, padInts[0], indexType);
hTile[RIGHT] = getConstant(rewriter, loc, padInts[1], indexType);

bool hasLeftPadding = false;
bool hasRightPadding = false;
bool hasTopPadding = false;
bool hasBottomPadding = false;

for (auto i: {TOP, VCENTER, BOTTOM}){
for (auto i : {TOP, VCENTER, BOTTOM}) {
for (auto j: {LEFT, HCENTER, RIGHT}) {
for (auto j : {LEFT, HCENTER, RIGHT}) {
auto constVtile{
mlir::dyn_cast<mlir::arith::ConstantOp>(vTile[i].getDefiningOp())
.getValue()
.dyn_cast_or_null<mlir::IntegerAttr>()};

auto constHtile{
mlir::dyn_cast<mlir::arith::ConstantOp>(hTile[j].getDefiningOp())
.getValue()
.dyn_cast_or_null<mlir::IntegerAttr>()};
auto vSize = constVtile.getInt();
auto hSize = constHtile.getInt();

if ((i == TOP) && (vSize > 0))
hasTopPadding = true;
if ((i == BOTTOM) && (vSize > 0))
hasBottomPadding = true;
if ((j == LEFT) && (hSize > 0))
hasLeftPadding = true;
if ((j == RIGHT) && (hSize > 0))
hasRightPadding = true;
}
}

auto createSub = [&](Value x, Value y) {
return rewriter.create<arith::SubIOp>(loc, x, y);
};

// Extract left and right pad tiles.
Value zero = getConstant(rewriter, loc, 0, indexType);
Value one = getConstant(rewriter, loc, 1, indexType);
Value hDimSizeMinusOne = createSub(hDimSize, one);
Value vDimSizeMinusOne = createSub(vDimSize, one);
SmallVector<Value> allOneStrides(numDims, one);

SmallVector<Value> extractOffsetsLT(numDims, zero);
extractOffsetsLT[hDim] = zero;
extractOffsetsLT[vDim] = zero;
SmallVector<Value> extractShapeLR(numDims, one);
extractShapeLR[hDim] = one;
extractShapeLR[vDim] = vDimSize;

SmallVector<Value> extractOffsetsRight(numDims, zero);
extractOffsetsRight[hDim] = hDimSizeMinusOne;
extractOffsetsRight[vDim] = zero;

SmallVector<Value> extractOffsetsBottom(numDims, zero);
extractOffsetsBottom[hDim] = zero;
extractOffsetsBottom[vDim] = vDimSizeMinusOne;

SmallVector<Value> extractShapeTB(numDims, one);
extractShapeTB[hDim] = hDimSize;
extractShapeTB[vDim] = one;

SmallVector<Value> tensorsLeft;
SmallVector<Value> tensorsRight;
SmallVector<Value> tensorsCenter;
Value centerTile;
SmallVector<Value> tensorsRes;

if (hasLeftPadding) {
Value vCenterLeftSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsetsLT, extractShapeLR, allOneStrides);
Value vLeftSlice = vCenterLeftSlice;
if (hasTopPadding) {
Value topLeftValue = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{zero, zero, zero, zero});
//pad vCenterLeftSlice on the top
// pad vCenterLeftSlice on the top
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
lowPadding[2] = padInts[2];
vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue);
vLeftSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue);
}
if (hasBottomPadding) {
Value bottomLeftValue = rewriter.create<tensor::ExtractOp> (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero});
Value bottomLeftValue = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero});

//pad vLeftSlice at the bottom
// pad vLeftSlice at the bottom
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
highPadding[2] = padInts[3];
vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue);
vLeftSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue);
}
for (auto i=0; i<padInts[0]; ++i) {
for (auto i = 0; i < padInts[0]; ++i) {
tensorsLeft.push_back(vLeftSlice);
}
Value leftPadTile =
rewriter.create<tensor::ConcatOp>(loc, 3, tensorsLeft);
tensorsRes.push_back(leftPadTile);
}
if (hasTopPadding) {
Value topHcenterSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsetsLT, extractShapeTB, allOneStrides);
for (auto i = 0; i < padInts[2]; ++i) {
tensorsCenter.push_back(topHcenterSlice);
}
}
tensorsCenter.push_back(input);
if (hasBottomPadding) {
Value bottomHcenterSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsetsBottom, extractShapeTB, allOneStrides);
for (auto i = 0; i < padInts[3]; ++i) {
tensorsCenter.push_back(bottomHcenterSlice);
}
}
centerTile = rewriter.create<tensor::ConcatOp>(loc, 2, tensorsCenter);
tensorsRes.push_back(centerTile);

if (hasRightPadding) {
Value vCenterRightSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsetsRight, extractShapeLR, allOneStrides);
Value vRightSlice = vCenterRightSlice;
if (hasTopPadding) {
Value topRightValue = rewriter.create<tensor::ExtractOp> (loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne});
Value topRightValue = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne});

//pad vCenterRightSlice on the top
// pad vCenterRightSlice on the top
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
lowPadding[2] = padInts[2];
vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue);
vRightSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue);
}
if (hasBottomPadding) {
Value bottomRightValue = rewriter.create<tensor::ExtractOp> (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne});
Value bottomRightValue = rewriter.create<tensor::ExtractOp>(
loc, input,
ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne});

// Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom.
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
highPadding[2] = padInts[3];
vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, bottomRightValue);
vRightSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vRightSlice, lowPadding, highPadding,
bottomRightValue);
}
for (auto i=0; i<padInts[1]; ++i) {
for (auto i = 0; i < padInts[1]; ++i) {
tensorsRight.push_back(vRightSlice);
}
Value rightPadTile = rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRight);
Value rightPadTile =
rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRight);
tensorsRes.push_back(rightPadTile);
}
Value resTensor = rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRes);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, resTensor);
return success();
}
};
}
} // namespace
namespace {
// Converts constant tensor allocation like ops.

@ -348,8 +363,8 @@ public:
// Create an uninitialized tensor of `resultSize` shape and fill it with
// value `fillVal`.
Value constVal = getConstant(rewriter, loc, fillVal, resultElementType);
Value outputTensor =
createInitTensor(rewriter, loc, resultSizeIndex, resultElementType, constVal);
Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex,
resultElementType, constVal);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
return success();
}

@ -384,7 +399,8 @@ public:
// Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
if (!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)))
return rewriter.notifyMatchFailure(
op, "unimplemented: the memory format should be specified in "
"an integer constant");

@ -495,7 +511,8 @@ public:
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type dtype = resultType.getElementType();
Value start = convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype);
Value start =
convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype);
Value end = convertScalarToDtype(rewriter, loc, adaptor.getEnd(), dtype);
Value step = convertScalarToDtype(rewriter, loc, adaptor.getStep(), dtype);
@ -426,10 +426,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
}
if (isa<AtenAbsOp>(op))
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
if (isa<AtenIsinfOp>(op)){
if (isa<AtenIsinfOp>(op)) {
Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
Value infinity = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
loc,
b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
return createEqual(b, loc, abs.getType(), abs, infinity);
}
if (isa<AtenSigmoidOp>(op)) {

@ -7,13 +7,13 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@ -923,8 +923,7 @@ LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
op.getA().getType().template cast<BaseTensorType>().getDtype();
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
auto result =
rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
auto result = rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());

rewriter.replaceOp(
op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype));

@ -1797,8 +1796,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(

#define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, \
context)
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, context)

INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp);
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp);

@ -30,8 +30,8 @@ using namespace mlir::torch::torch_to_stablehlo;

namespace {
static Value createInitialValueForGatherScatterOp(Operation *op,
RankedTensorType constType,
PatternRewriter &rewriter) {
auto elementTy = constType.getElementType();
if (isa<AtenEmbeddingBagPaddingIdxOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
@ -35,7 +35,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementTy);
// Avg pooling
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
AtenCumsumOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(

@ -373,7 +374,6 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
return success();
}

namespace {
template <typename AtenOpT, int Dim>
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
@ -388,45 +388,45 @@ public:
Type inputElemTy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank();
RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<RankedTensorType>();
auto outShape = outTy.getShape();

if (inputRank <= Dim) {
return op.emitError(
"avg_pooling1d/2d only supports inputs with rank higher than 1/2");
}
SmallVector<int64_t, Dim> padding, kernelSize, stride;
bool ceilMode = false;
bool countIncludePad = true;

if (!(matchPattern(op.getKernelSize(),
m_TorchListOfConstantInts(kernelSize)))) {
return rewriter.notifyMatchFailure(
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
return rewriter.notifyMatchFailure(op,
"non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(op,
"non-const bool ceil_mode unsupported!");
return rewriter.notifyMatchFailure(
op, "non-const bool ceil_mode unsupported!");
}
if (!(matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)))) {
return rewriter.notifyMatchFailure(
op, "non-const bool count_include_pad unsupported!");
}

if constexpr (std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride())))
return rewriter.notifyMatchFailure(
op, "only None divisor_override supported for now!");
}

// Prepend 1 to kernelSize, stride, dilation until they are of same rank
@ -437,33 +437,35 @@ public:
SmallVector<int64_t> stablehloPadding(inputRank * 2, 0);

std::copy(stride.begin(), stride.end(),
stablehloStride.begin() + inputRank - Dim);
std::copy(kernelSize.begin(), kernelSize.end(),
stablehloKernelSize.begin() + inputRank - Dim);
if (Dim == 1) {
stablehloPadding[stablehloPadding.size() - 2] = padding[0];
stablehloPadding[stablehloPadding.size() - 1] = padding[0];
} else {
stablehloPadding[stablehloPadding.size() - 4] = padding[0];
stablehloPadding[stablehloPadding.size() - 3] = padding[0];
stablehloPadding[stablehloPadding.size() - 2] = padding[1];
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
}

Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
Value initVal =
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);

DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
RankedTensorType::get(
{static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloStride.size())},
rewriter.getI64Type()),
stablehloStride);
DenseIntElementsAttr baseDilations;
DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloDilation.size())},
rewriter.getI64Type()),
stablehloDilation);
DenseIntElementsAttr pad = DenseIntElementsAttr::get(
RankedTensorType::get(
@ -485,31 +487,31 @@ public:
auto secondArg = *sumBlock.args_rbegin();

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&sumBlock);

Value sumResult =
rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
}

// Use kernel size as the divisor
if (countIncludePad) {
Value divisor;
if (Dim == 1) {
divisor =
hlo::getConstTensor<int64_t>(rewriter, op, {kernelSize[0]}, {})
.value();
} else {
divisor = hlo::getConstTensor<int64_t>(
rewriter, op, {kernelSize[0] * kernelSize[1]}, {})
.value();
}
divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<mlir::chlo::BroadcastDivOp>(
op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions);
return success();
}

// Use another mhlo.ReduceWindowOp to get the divisor
@ -518,8 +520,8 @@ public:
windowSizeConst =
hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
auto inputShapeVec =
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);

@ -544,23 +546,20 @@ public:
secondArg = *sizeBlock.args_rbegin();

{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&sizeBlock);

Value sumResult =
rewriter.create<stablehlo::AddOp>(op->getLoc(), firstArg, secondArg);
rewriter.create<stablehlo::ReturnOp>(op->getLoc(), sumResult);
}

rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
return success();

}

};
}
} // namespace

// AtenCumsumOp
template <>
@ -660,10 +659,10 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
context, options);
target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>( \
typeConverter, context, options)
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
options)
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
#undef INSERT_ATEN_AVGPOOL_PATTERN

@ -16,13 +16,13 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"

using namespace mlir;
using namespace mlir::torch;
@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

@ -22,7 +23,6 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include <numeric>

using namespace mlir;

@ -403,7 +403,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant");
int64_t inputRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
int64_t inputRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
dim = toPositiveDim(dim, inputRank + 1);
if (!isValidDim(dim, inputRank + 1))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
@ -131,10 +131,10 @@ tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
}

std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
Operation *op,
Value paramsValue,
Value indexValue,
int32_t axis) {
// For easy understanding of this algorithm, the following comments are with
// an exact example: torch.aten.gather(!torch.vtensor<[1,4,3],f32>, axis=2,
// !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32>

@ -210,9 +210,9 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
// Lowers Gather operators to a sequence of TOSA ops.
// taken from
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
Operation *op, Type outType,
Value paramsValue, Value indicesValue) {
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
Type outType, Value paramsValue,
Value indicesValue) {
auto resultType = outType.dyn_cast<ShapedType>();
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
@ -683,7 +683,6 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
.getResult();
}

// Common function for lowering reduce operations to TOSA ops.
template <typename T>
std::optional<Value> convertReduceOpCommon(

@ -721,9 +720,8 @@ std::optional<Value> convertReduceOpCommon(
auto axis_attr = rewriter.getI32IntegerAttr(axis_val);

shape_vec[axis_val] = 1;
RankedTensorType reduce_type = RankedTensorType::get(
shape_vec,
reduce_element_type);
RankedTensorType reduce_type =
RankedTensorType::get(shape_vec, reduce_element_type);

auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
val, axis_attr);
@@ -176,7 +176,8 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
// Default template creates a constant tensor in T.
template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
                                    ArrayRef<T> vec, ArrayRef<int64_t> shape,
                                    std::optional<Type> dtype) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;

@@ -188,7 +189,7 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
  }

  auto width = sizeof(T) * 8;
  if constexpr (std::is_same_v<T, bool>)
    width = 1;

  auto const_type =

@@ -199,7 +200,7 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);

  if (dtype) {
    return rewriter.createOrFold<tosa::CastOp>(
        op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
  }
  return const_op.getResult();

@@ -209,7 +210,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
template <>
std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
                                           Operation *op, ArrayRef<APInt> vec,
                                           ArrayRef<int64_t> shape,
                                           std::optional<Type> dtype) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;

@@ -228,7 +230,7 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);

  if (dtype) {
    return rewriter.createOrFold<tosa::CastOp>(
        op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
  }
  return const_op.getResult();

@@ -238,7 +240,8 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
template <>
std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
                                           Operation *op, ArrayRef<float> vec,
                                           ArrayRef<int64_t> shape,
                                           std::optional<Type> dtype) {
  uint64_t num_total_elements = 1;
  for (int64_t a : shape) {
    num_total_elements *= a;

@@ -256,7 +259,7 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
      rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);

  if (dtype) {
    return rewriter.createOrFold<tosa::CastOp>(
        op->getLoc(), RankedTensorType::get(shape, *dtype), const_op);
  }
  return const_op.getResult();

@@ -347,23 +350,17 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
}

// Template instantiation
template std::optional<Value>
getConstTensor<bool>(PatternRewriter &, Operation *, ArrayRef<bool> vec,
                     ArrayRef<int64_t> shape, std::optional<Type> dtype);

template std::optional<Value>
getConstTensor<int32_t>(PatternRewriter &, Operation *, ArrayRef<int32_t> vec,
                        ArrayRef<int64_t> shape, std::optional<Type> dtype);

template std::optional<Value>
getConstTensor<int64_t>(PatternRewriter &, Operation *, ArrayRef<int64_t> vec,
                        ArrayRef<int64_t> shape, std::optional<Type> dtype);

LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
                                  TypeAttr &accType) {
@@ -87,7 +87,8 @@ static TMTensorOp createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter,
                                            ValueRange outputs) {
  SmallVector<Value, 8> newOperands = inputs;
  newOperands.append(outputs.begin(), outputs.end());
  return cast<TMTensorOp>(
      tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands));
}

/// Generic conversion pattern that matches any TMTensorOp. This avoids template
@@ -157,7 +157,7 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
      return builder.create<Torch::ConstantNumberOp>(loc, intValue);
    }
  }

  if (type.isa<Torch::BoolType>()) {
    return builder.create<Torch::ConstantBoolOp>(loc,
                                                 value.cast<IntegerAttr>());
@@ -203,8 +203,8 @@ static Value getScalarFloatValue(Value input, Location loc,
//===----------------------------------------------------------------------===//

LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto func = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
      *this, getFunctionAttr());
  if (!func)
    return emitError() << "'@" << getFunction()
                       << "' does not reference a valid function";
@@ -453,11 +453,13 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  // If the condition is constant, delete the dead branch and inline the live
  // branch.
  patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
    auto constantBool =
        op.getCondition().getDefiningOp<Torch::ConstantBoolOp>();
    if (!constantBool)
      return rewriter.notifyMatchFailure(op, "non-constant condition");
    replaceOpWithRegion(rewriter, op,
                        constantBool.getValue() ? op.getThenRegion()
                                                : op.getElseRegion());
    return success();
  });
  // If the thenRegion and elseRegion yield the same Value's, then use those

@@ -515,14 +517,16 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
        continue;
      newResultTypes.push_back(op->getResult(i).getType());
    }
    auto newIf = rewriter.create<PrimIfOp>(op->getLoc(), newResultTypes,
                                           op.getCondition());
    rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(),
                                newIf.getThenRegion().end());
    rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(),
                                newIf.getElseRegion().end());
    newIf.getThenRegion().front().getTerminator()->eraseOperands(
        resultsToErase);
    newIf.getElseRegion().front().getTerminator()->eraseOperands(
        resultsToErase);
    SmallVector<Value> replacementValues;
    for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) {
      if (resultsToErase[i])
@@ -548,8 +552,8 @@ void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
      return failure();

    if (value) {
      rewriter.eraseOp(op);
      return success();
    }
    // Even if the condition is statically false, the assert might never be
    // executed.

@@ -898,10 +902,10 @@ void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
    auto rhs = op.getOther();
    auto getRhsDevice = rewriter.create<PrimDeviceOp>(op.getLoc(), rhs);
    auto getRhsDtype = rewriter.create<PrimDtypeOp>(op.getLoc(), rhs);
    rewriter.replaceOpWithNewOp<AtenToDeviceOp>(
        op, op.getType(), lhs, getRhsDevice.getResult(),
        getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(),
        op.getMemoryFormat());
    return success();
  });
}

@@ -996,7 +1000,7 @@ void AtenMaxOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
  // `aten.max.other` -> `aten.maximum`
  patterns.add(+[](AtenMaxOtherOp op, PatternRewriter &rewriter) {
    rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), op.getSelf(),
                                               op.getOther());
    return success();
  });
}
@@ -1934,7 +1938,7 @@ void Torch::ConstantFloatOp::getAsmResultNames(
  // float string representation).
  SmallVector<char> buf;
  getValue().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
                      /*TruncateZero=*/false);
  auto isValidMLIRIdentifierChar = [](char c) {
    return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' ||
           c == '-';

@@ -2045,7 +2049,8 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
    // compiler treat the size as having value semantics?
    // There's a small number of such ops, and they are marked as `inplace_view`
    // in PyTorch's `native_functions.yaml` file.
    rewriter.replaceOpWithNewOp<AtenSizeIntOp>(op, sizeOp.getSelf(),
                                               op.getIdx());
    return success();
  });
}
@@ -2073,11 +2078,13 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
    auto lhsListConstruct =
        op.getA().getDefiningOp<Torch::PrimListConstructOp>();
    if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
      return failure();

    auto rhsListConstruct =
        op.getB().getDefiningOp<Torch::PrimListConstructOp>();
    if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
      return failure();

@@ -2195,7 +2202,8 @@ LogicalResult PrimTupleConstructOp::verify() {
void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                   MLIRContext *context) {
  patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
    auto tupleConstruct =
        op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
    if (!tupleConstruct)
      return failure();

@@ -2245,7 +2253,8 @@ void PrimUninitializedOp::getCanonicalizationPatterns(
void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
    auto tupleConstruct =
        op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
    if (!tupleConstruct)
      return failure();
@@ -2400,9 +2409,7 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
// AtenAliasOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); }

//===----------------------------------------------------------------------===//
// AtenFloordivIntOp

@@ -2481,14 +2488,12 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//

OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
  int64_t start, end, step;
  if (matchPattern(getStart(), m_TorchConstantInt(&start)) &&
      matchPattern(getEnd(), m_TorchConstantInt(&end)) &&
      matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 &&
      start == 0 && end == std::numeric_limits<int64_t>::max())
    return getOperand(0);

  auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
  auto outType = getResult().getType().dyn_cast<BaseTensorType>();
@@ -2744,7 +2749,7 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) {
  // aten.Int.Tensor, fold to the scalar number.
  if (auto numToTensorScalar = getA().getDefiningOp<PrimNumToTensorScalarOp>())
    return numToTensorScalar.getA();
  if (auto tensorIntOp = getA().getDefiningOp<AtenTensorIntOp>())
    return tensorIntOp.getT();
  return nullptr;
}

@@ -2955,7 +2960,6 @@ LogicalResult AtenPermuteOp::verify() {
           << " elements, the output has rank " << outRank << '.';
  }

  // Initialization of the reverse permutation. -1 denotes an unknown
  // permutation index.
  SmallVector<int64_t> reversePermutation(outRank, -1);
@@ -440,7 +440,7 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
  } else if (auto integerType = dtype.dyn_cast<IntegerType>()) {
    return IntegerType::get(context, integerType.getWidth(),
                            IntegerType::Signless);
  } else if (dtype.isa<mlir::ComplexType>()) {
    return dtype;
  }

@@ -556,9 +556,9 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {

// TODO: These are not DRY in that the two type predicates AnyTorchDictKeyType
// and AnyTorchType generate the exact same code (in TorchOps.cpp.inc).
// Unfortunately the generated implementations aren't visible/exposed ("static"
// linkage) and the predicates themselves can't be added/used in the
// specification of the parameters of the Torch_DictType.
static bool isAnyTorchDictKeyType(Type type) {
  return type.isa<Torch::AnyType>() || type.isa<Torch::IntType>() ||
         type.isa<Torch::BoolType>() || type.isa<Torch::FloatType>() ||
@@ -355,7 +355,7 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
  auto rhsType = rhs.getType().cast<BaseTensorType>();

  Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
                                        : rhsType.getOptionalDtype();

  llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
  for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {

@@ -457,7 +457,6 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
  return success();
}

static Value performLastReduceAndPermute(PatternRewriter &rewriter,
                                         Location loc, Type outType,
                                         Value input,
@@ -1269,7 +1268,8 @@ public:
};
} // namespace

// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into
// `AtenMinDimOp`
namespace {
template <typename OpTy, typename DecompOpTy>
class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {

@@ -1300,9 +1300,9 @@ public:
            .cast<BaseTensorType>();

    // If the dim type is `NoneType` i.e. reduce along all the dimensions.
    // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
    // first the input tensor is flattened to 1d tensor and then the reduction
    // happens on the 0th dimension.
    if (dim.getType().isa<Torch::NoneType>()) {
      BaseTensorType flattenType =
          inputType

@@ -1317,11 +1317,11 @@ public:
    }

    Value resultArg =
        rewriter
            .create<DecompOpTy>(loc, valueTensorType, indicesTensorType, input,
                                dim, keepDim)
            .getIndices();

    rewriter.replaceOp(op, resultArg);
    return success();
  }
@@ -1959,10 +1959,12 @@ public:
    // Define λ and α
    double scale = 1.0507009873554804934193349852946;
    double alpha = 1.6732632423543772848170429916717;

    // Create constants for λ and α
    Value scaleVal = rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(scale));
    Value alphaVal = rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(alpha));

    // Create zero tensor for comparison
    Value constantZero =

@@ -1972,17 +1974,21 @@ public:
    // Calculate positive and negative parts
    Value constantOne =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
    Value positiveOutput =
        rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
    Value minZeroX =
        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
    Value expInput = rewriter.create<AtenExpOp>(loc, resType, minZeroX);
    Value expInputMinusOne = rewriter.create<AtenSubScalarOp>(
        loc, resType, expInput, constantOne, constantOne);
    Value negativeOutput = rewriter.create<AtenMulScalarOp>(
        loc, resType, expInputMinusOne, alphaVal);

    // Multiply the result by λ
    Value seluOutput = rewriter.create<AtenAddTensorOp>(
        loc, resType, positiveOutput, negativeOutput, constantOne);
    seluOutput =
        rewriter.create<AtenMulScalarOp>(loc, resType, seluOutput, scaleVal);

    // Replace the original operation
    rewriter.replaceOp(op, seluOutput);
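For reference, the aten ops built above compute the usual SELU formula, selu(x) = λ * (max(0, x) + α * (exp(min(0, x)) - 1)). A minimal standalone sketch of the same arithmetic, illustrative only and not part of the lowering, using the λ and α literals from the hunk:

```
// Illustrative sketch only: scalar SELU matching the decomposition above.
#include <algorithm>
#include <cmath>
#include <cstdio>

double selu(double x) {
  const double scale = 1.0507009873554804934193349852946; // λ
  const double alpha = 1.6732632423543772848170429916717; // α
  double positive = std::max(0.0, x);                      // max(0, x)
  double negative = alpha * (std::exp(std::min(0.0, x)) - 1.0);
  return scale * (positive + negative);
}

int main() {
  std::printf("selu(1.0) = %f, selu(-1.0) = %f\n", selu(1.0), selu(-1.0));
  return 0;
}
```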
@@ -2592,79 +2598,89 @@ public:
namespace {

static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter,
                                                      Location loc, Value input,
                                                      int64_t dimA,
                                                      int64_t dimB,
                                                      Value &transposed) {
  Type transposedType;
  if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
                               dimA, dimB, transposedType)))
    return failure();
  Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(dimA));
  Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
      loc, rewriter.getI64IntegerAttr(dimB));
  transposed = rewriter.create<Torch::AtenTransposeIntOp>(
      loc, transposedType, input, cstDimA, cstDimB);
  return success();
}

class DecomposeAtenConvTbcOp : public OpRewritePattern<AtenConvTbcOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenConvTbcOp op,
                                PatternRewriter &rewriter) const override {
    Value emptyList = rewriter.create<PrimListConstructOp>(
        op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
        SmallVector<Value>());
    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
    Value oneList = rewriter.create<PrimListConstructOp>(
        op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
        SmallVector<Value>{rewriter.create<Torch::ConstantIntOp>(
            op.getLoc(), rewriter.getI64IntegerAttr(1))});
    Value padding = rewriter.create<PrimListConstructOp>(
        op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
        SmallVector<Value>{op.getPad()});
    Value groups = rewriter.create<Torch::ConstantIntOp>(
        op.getLoc(), rewriter.getI64IntegerAttr(1));

    // convtbc has WNC layout for input and output
    // and WCF layout for weight
    // whereas Convolution is going to use Conv1DNcwFcwOp for 1d
    // which means we need the inputs in NCW and the weight in FCW
    Value selfWnc = op.getSelf();
    Value selfNwc;
    Value selfNcw;
    if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc,
                                                0, 1, selfNwc)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to transpose input to Nwc");
    if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc,
                                                1, 2, selfNcw)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to transpose input to Ncw");

    Value weightWcf = op.getWeight();
    Value weightFcw;
    if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(),
                                                weightWcf, 0, 2, weightFcw)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to transpose weight to Fcw");

    Value outputNcw = rewriter.create<AtenConvolutionOp>(
        op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(),
        /*stride*/ oneList,
        /*padding*/ padding, /*dilation*/ oneList,
        /*transpose*/ cstFalse, /*output_padding*/ emptyList, groups);

    // convert output from Ncw to Wnc
    Value outputNwc;
    Value outputWnc;
    if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(),
                                                outputNcw, 1, 2, outputNwc)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to transpose output to Nwc");
    if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(),
                                                outputNwc, 0, 1, outputWnc)))
      return rewriter.notifyMatchFailure(op,
                                         "failed to transpose output to Wnc");
    rewriter.replaceOp(op, outputWnc);

    return success();
  }
};
} // namespace

// Decompose aten.conv1d to aten.convolution
namespace {
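The layout comments in the conv_tbc decomposition above are easiest to follow as plain shape bookkeeping: the input arrives as (W, N, C) and the weight as (W, C, F), while the 1-d convolution path wants (N, C, W) input and (F, C, W) weight. A small illustrative sketch, with made-up example sizes, of how the transposes (0,1), (1,2), and (0,2) achieve that:

```
// Illustrative sketch only: shape permutations behind the conv_tbc transposes.
#include <array>
#include <cstdio>
#include <utility>

using Shape3 = std::array<long long, 3>;

Shape3 transpose(Shape3 s, int a, int b) {
  std::swap(s[a], s[b]); // swap the two dimensions, as AtenTransposeIntOp does
  return s;
}

int main() {
  Shape3 inputWnc = {10, 2, 16};  // (W, N, C), example sizes
  Shape3 weightWcf = {3, 16, 32}; // (W, C, F), example sizes
  Shape3 inputNwc = transpose(inputWnc, 0, 1);   // -> (N, W, C)
  Shape3 inputNcw = transpose(inputNwc, 1, 2);   // -> (N, C, W)
  Shape3 weightFcw = transpose(weightWcf, 0, 2); // -> (F, C, W)
  std::printf("input  NCW: %lld %lld %lld\n", inputNcw[0], inputNcw[1],
              inputNcw[2]);
  std::printf("weight FCW: %lld %lld %lld\n", weightFcw[0], weightFcw[1],
              weightFcw[2]);
  return 0;
}
```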
@@ -3815,8 +3831,8 @@ public:
        /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
    Value stdRandN =
        rewriter.create<AtenMulScalarOp>(loc, resultType, randN, std);
    rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN, mean,
                                                 /*alpha=*/one);
    return success();
  }
};
@@ -6654,8 +6670,10 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
    addPatternIfTargetOpIsIllegal<
        DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);

@@ -6768,8 +6786,6 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenConv3dOp>(patterns);

    GreedyRewriteConfig config;
    config.useTopDownTraversal = true;
    config.maxIterations = GreedyRewriteConfig::kNoLimit;
@@ -170,8 +170,8 @@ private:
      auto attr = std::get<1>(t);
      nameStack.push_back(attr.getName().str());
      if (attr.getType().isa<NnModuleType>()) {
        if (failed(recursivelyTraverse(
                slot.getValue().getDefiningOp<NnModuleOp>())))
          return failure();
      } else if (usedSlots.find(slot) != usedSlots.end()) {
        // Only create the GlobalSlotOp if the slot is used at all.

@@ -190,8 +190,8 @@ private:
    }
    for (auto method : classType.getOps<MethodOp>()) {
      nameStack.push_back(method.getName().str());
      funcLinkageInfo[{
          nnModule, symbolTable.lookup<func::FuncOp>(method.getFunction())}] =
          LinkageInfo{llvm::join(nameStack, "."), method.getIsPrivate()};
      nameStack.pop_back();
    }
@@ -501,21 +501,24 @@ static LogicalResult rewriteMonomorphizedFuncClone(

  SmallVector<Operation *> toErase;
  auto handlePrimSetAttr = [&](PrimSetAttrOp op) {
    auto instance =
        mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
    SlotOp affectedSlot;
    for (auto slot : instance.getOps<SlotOp>()) {
      if (slot.getName() == op.getName())
        affectedSlot = slot;
    }
    OpBuilder(op).create<GlobalSlotSetOp>(
        op.getLoc(),
        objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(),
        op.getValue());
    toErase.push_back(op);
    return WalkResult::advance();
  };
  auto handlePrimGetAttr = [&](PrimGetAttrOp op) {
    if (!op.getType().isa<NnModuleType>()) {
      auto instance =
          mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
      SlotOp affectedSlot;
      for (auto slot : instance.getOps<SlotOp>()) {
        if (slot.getName() == op.getName())
@@ -163,7 +163,8 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
    }
    if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
      auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
          getProgramPoint<FlatSymbolRefProgramPoint>(
              globalSlotSet.getSlotAttr()));
      propagateIfChanged(state, state->setSafe(false));
    }
    // Save the InitializeGlobalSlotsOp for later referencee

@@ -211,8 +212,8 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
    auto it =
        llvm::find(initializeGlobalSlotsOp.getSlotSymNames(),
                   static_cast<Attribute>(flatSymbolRefPoint->getValue()));
    Value value = initializeGlobalSlotsOp->getOperand(std::distance(
        initializeGlobalSlotsOp.getSlotSymNames().begin(), it));
    auto *flatSymbolRefState =
        getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
                                                       flatSymbolRefPoint);

@@ -331,7 +332,8 @@ class InlineGlobalSlotsPass

    DenseSet</*FlatSymbolRefAttr*/ Attribute> safeToInline;
    for (int i = 0, e = initialize->getNumOperands(); i != e; i++) {
      auto slotSymName =
          initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
      Value operand = initialize.getOperand(i);
      auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
          initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>());

@@ -405,7 +407,8 @@ class InlineGlobalSlotsPass
    SmallVector<Attribute> newSlotSymNames;
    SmallVector<Value> newInitialValues;
    for (int i = 0, e = initialize.getNumOperands(); i != e; i++) {
      auto slotSymName =
          initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
      if (!safeToInline.count(slotSymName)) {
        newSlotSymNames.push_back(slotSymName);
        newInitialValues.push_back(initialize.getOperand(i));
|
@ -202,15 +202,16 @@ static bool satisfiesBackendContract(ModuleOp module,
|
||||||
// Check for unimplemented operators first to give more direct diagnostics.
|
// Check for unimplemented operators first to give more direct diagnostics.
|
||||||
walkResult0 = module.walk([&](Torch::OperatorOp op) {
|
walkResult0 = module.walk([&](Torch::OperatorOp op) {
|
||||||
if (llvm::all_of(op.getResults(), [&op](auto res) {
|
if (llvm::all_of(op.getResults(), [&op](auto res) {
|
||||||
return succeeded(
|
return succeeded(checkType(op.getOperation(), res.getType(),
|
||||||
checkType(op.getOperation(), res.getType(), /*actuallyEmitDiagnostics=*/false));
|
/*actuallyEmitDiagnostics=*/false));
|
||||||
})) {
|
})) {
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (actuallyEmitDiagnostics) {
|
if (actuallyEmitDiagnostics) {
|
||||||
op->emitError("unsupported by backend contract: Unimplemented operator '"
|
op->emitError(
|
||||||
+ op.getName() + "'");
|
"unsupported by backend contract: Unimplemented operator '" +
|
||||||
|
op.getName() + "'");
|
||||||
}
|
}
|
||||||
return WalkResult::interrupt();
|
return WalkResult::interrupt();
|
||||||
});
|
});
|
||||||
|
@ -309,20 +310,22 @@ public:
|
||||||
<< " iterations of the simplification pipeline\n";
|
<< " iterations of the simplification pipeline\n";
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
llvm::StringSet<> backendLegalOpsSet;
|
llvm::StringSet<> backendLegalOpsSet;
|
||||||
};
|
};
|
||||||
|
|
||||||
class VerifyBackendContractNoDecompositionsPass
|
class VerifyBackendContractNoDecompositionsPass
|
||||||
: public VerifyBackendContractNoDecompositionsBase<VerifyBackendContractNoDecompositionsPass> {
|
: public VerifyBackendContractNoDecompositionsBase<
|
||||||
|
VerifyBackendContractNoDecompositionsPass> {
|
||||||
public:
|
public:
|
||||||
VerifyBackendContractNoDecompositionsPass() = default;
|
VerifyBackendContractNoDecompositionsPass() = default;
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
MLIRContext *context = &getContext();
|
MLIRContext *context = &getContext();
|
||||||
ConversionTarget target =
|
ConversionTarget target =
|
||||||
getBackendContractTarget(context, /*decompose*/false,
|
getBackendContractTarget(context, /*decompose*/ false,
|
||||||
/*backendLegalOpsSet*/{});
|
/*backendLegalOpsSet*/ {});
|
||||||
|
|
||||||
if (!satisfiesBackendContract(getOperation(), target,
|
if (!satisfiesBackendContract(getOperation(), target,
|
||||||
/*actuallyEmitDiagnostics=*/true)) {
|
/*actuallyEmitDiagnostics=*/true)) {
|
||||||
|
|
|
@@ -158,9 +158,11 @@ void Torch::importLibraryFunctions(ModuleOp module, ModuleOp library,
  }
}

FailureOr<Value>
Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
                         Type desiredType,
                         function_ref<Value(OpBuilder &, Location, Value, Type)>
                             baseTransformation) {
  operand = baseTransformation(b, loc, operand, desiredType);

  // No need for adjustment if they already match.

@@ -90,7 +90,8 @@ public:
                                PatternRewriter &rewriter) const override {
    SmallVector<std::optional<int64_t>> ranks;
    SmallVector<int64_t> dtypes;
    if (!matchPattern(op.getRanks(),
                      m_TorchListOfOptionalConstantInts(ranks))) {
      return rewriter.notifyMatchFailure(
          op, "Expected `ranks` to be a list of optional constant ints");
    }
@@ -344,9 +344,9 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
// Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If
// yes, then computes the final broadcast shape.
void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
                                  Value inputA, Value inputB,
                                  SmallVector<int64_t> &resultShape,
                                  SmallVector<Value> &resultShapeValue) {
  SmallVector<int64_t> shapeA{
      inputA.getType().cast<BaseTensorType>().getSizes()};
  SmallVector<int64_t> shapeB{

@@ -514,7 +514,7 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc,
}

LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA,
                                       int64_t dimB, Type &transposedType) {
  if (!inType.hasSizes())
    return failure();
  SmallVector<int64_t> shape(inType.getSizes());
@@ -54,14 +54,14 @@ void TorchConversionDialect::initialize() {
  addInterfaces<TorchConversionInlinerInterface>();
}

//===----------------------------------------------------------------------===//
// Constant materializer.
//===----------------------------------------------------------------------===//

Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder,
                                                       Attribute value,
                                                       Type type,
                                                       Location loc) {
  if (auto integerType = type.dyn_cast<Torch::IntType>())
    return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>());
@@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//

#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
using namespace mlir::torch;
@@ -57,16 +57,16 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
  typeConverter.addConversion([](Torch::BoolType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 1);
  });
  typeConverter.addTargetMaterialization(
      [](OpBuilder &builder, IntegerType type, ValueRange inputs,
         Location loc) -> std::optional<Value> {
        // Other builtin integer types could be handled by other materializers.
        if (!(type.getWidth() == 1 && type.isSignless()))
          return std::nullopt;
        assert(inputs.size() == 1);
        assert(inputs[0].getType().isa<Torch::BoolType>());
        return builder.create<ToI1Op>(loc, inputs[0]).getResult();
      });
  auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type,
                                  ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
@@ -83,19 +83,19 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
  typeConverter.addConversion([](Torch::IntType type) -> std::optional<Type> {
    return IntegerType::get(type.getContext(), 64);
  });
  typeConverter.addTargetMaterialization(
      [](OpBuilder &builder, IntegerType type, ValueRange inputs,
         Location loc) -> std::optional<Value> {
        // Other builtin integer types could be handled by other materializers.
        if (!(type.getWidth() == 64 && type.isSignless()))
          return std::nullopt;
        // Other input type to be converted to i64 are handled by other
        // materializers.
        if (!inputs[0].getType().isa<Torch::IntType>())
          return std::nullopt;
        assert(inputs.size() == 1);
        return builder.create<ToI64Op>(loc, inputs[0]).getResult();
      });
  auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
                                  ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
@ -112,13 +112,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
|
||||||
typeConverter.addConversion([](Torch::FloatType type) -> std::optional<Type> {
|
typeConverter.addConversion([](Torch::FloatType type) -> std::optional<Type> {
|
||||||
return Float64Type::get(type.getContext());
|
return Float64Type::get(type.getContext());
|
||||||
});
|
});
|
||||||
typeConverter.addTargetMaterialization([](OpBuilder &builder,
|
typeConverter.addTargetMaterialization(
|
||||||
Float64Type type, ValueRange inputs,
|
[](OpBuilder &builder, Float64Type type, ValueRange inputs,
|
||||||
Location loc) -> std::optional<Value> {
|
Location loc) -> std::optional<Value> {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<Torch::FloatType>());
|
assert(inputs[0].getType().isa<Torch::FloatType>());
|
||||||
return builder.create<ToF64Op>(loc, inputs[0]).getResult();
|
return builder.create<ToF64Op>(loc, inputs[0]).getResult();
|
||||||
});
|
});
|
||||||
auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type,
|
auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type,
|
||||||
ValueRange inputs, Location loc) -> Value {
|
ValueRange inputs, Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
|
@ -133,22 +133,23 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
|
||||||
TypeConverter &typeConverter) {
|
TypeConverter &typeConverter) {
|
||||||
target.addLegalOp<TorchConversion::GeneratorToI64Op,
|
target.addLegalOp<TorchConversion::GeneratorToI64Op,
|
||||||
TorchConversion::I64ToGeneratorOp>();
|
TorchConversion::I64ToGeneratorOp>();
|
||||||
typeConverter.addConversion([](Torch::GeneratorType type) -> std::optional<Type> {
|
typeConverter.addConversion(
|
||||||
return IntegerType::get(type.getContext(), 64);
|
[](Torch::GeneratorType type) -> std::optional<Type> {
|
||||||
});
|
return IntegerType::get(type.getContext(), 64);
|
||||||
typeConverter.addTargetMaterialization([](OpBuilder &builder,
|
});
|
||||||
IntegerType type, ValueRange inputs,
|
typeConverter.addTargetMaterialization(
|
||||||
Location loc) -> std::optional<Value> {
|
[](OpBuilder &builder, IntegerType type, ValueRange inputs,
|
||||||
// Other builtin integer types could be handled by other materializers.
|
Location loc) -> std::optional<Value> {
|
||||||
if (!(type.getWidth() == 64 && type.isSignless()))
|
// Other builtin integer types could be handled by other materializers.
|
||||||
return std::nullopt;
|
if (!(type.getWidth() == 64 && type.isSignless()))
|
||||||
// Other input type to be converted to i64 are handled by other
|
return std::nullopt;
|
||||||
// materializers.
|
// Other input type to be converted to i64 are handled by other
|
||||||
if (!inputs[0].getType().isa<Torch::GeneratorType>())
|
// materializers.
|
||||||
return std::nullopt;
|
if (!inputs[0].getType().isa<Torch::GeneratorType>())
|
||||||
assert(inputs.size() == 1);
|
return std::nullopt;
|
||||||
return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult();
|
assert(inputs.size() == 1);
|
||||||
});
|
return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult();
|
||||||
|
});
|
||||||
auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type,
|
auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type,
|
||||||
ValueRange inputs, Location loc) -> Value {
|
ValueRange inputs, Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
|
|
|
@@ -18,8 +18,8 @@
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::torch;

@@ -65,7 +65,8 @@ public:

    auto getConstantIntegerFromDefiningOp = [](Value operand,
                                               int &extractedInt) {
      auto castOp =
          dyn_cast<mlir::UnrealizedConversionCastOp>(operand.getDefiningOp());
      if (!castOp) {
        return failure();
      }

@@ -83,7 +84,8 @@ public:
      return failure();
    }
    int unpackedBitWidth;
    if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth,
                                                unpackedBitWidth))) {
      return failure();
    }
    if (unpackedBitWidth !=

@@ -103,32 +105,35 @@ public:
    // expand lhs
    std::vector<int64_t> lhsExpandedShape = {lhsShape[0], lhsShape[1],
                                             lhsReductDimSize / gs, gs};
    RankedTensorType lhsExpandedType =
        RankedTensorType::get(lhsExpandedShape, elementType);
    SmallVector<ReassociationIndices, 4> lhsReassociation = {{0}, {1}, {2, 3}};
    Value lhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
        loc, lhsExpandedType, lhs, lhsReassociation);

    // expand rhs
    std::vector<int64_t> rhsExpandedShape = {rhsShape[0], rhsReductDimSize / gs,
                                             gs};
    RankedTensorType rhsExpandedType =
        RankedTensorType::get(rhsExpandedShape, rhsElementType);
    SmallVector<ReassociationIndices, 4> rhsReassociation = {{0}, {1, 2}};
    Value rhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
        loc, rhsExpandedType, rhsQuant, rhsReassociation);
    Value cst0 = rewriter.create<arith::ConstantOp>(
        loc, FloatAttr::get(elementType, 0.0));

    Value emptyDequant =
        rewriter.create<tensor::EmptyOp>(loc, rhsExpandedShape, elementType);
    SmallVector<Value> dynDims;
    for (int i = 0; i < lhsType.getRank(); i++) {
      if (lhsType.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, lhs, i));
      }
    }
    Value empty = rewriter.create<tensor::EmptyOp>(loc, resultShape,
                                                   elementType, dynDims);
    Value output =
        rewriter.create<linalg::FillOp>(loc, cst0, empty).getResult(0);

    AffineExpr d0, d1, d2, d3, d4;
    bindDims(getContext(), d0, d1, d2, d3, d4);

@@ -141,12 +146,12 @@ public:
    SmallVector<AffineMap, 4> dqIndexingMaps = {map, map1, map1, map};
    SmallVector<AffineMap, 4> matIndexingMaps = {map2, map3, map4};

    SmallVector<utils::IteratorType> dequantIteratorTypes(
        3, utils::IteratorType::parallel);
    SmallVector<utils::IteratorType> matmulIteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    Value rhsDequant =
        rewriter

@@ -157,9 +162,12 @@ public:
                /*iteratorTypes=*/dequantIteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {
                  Value w = args[0], scale = args[1], zeroPoint = args[2];
                  Value extw =
                      b.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), w);
                  Value fp_extw = b.create<arith::UIToFPOp>(
                      loc, rewriter.getF16Type(), extw);
                  Value shifted =
                      b.create<arith::SubFOp>(loc, fp_extw, zeroPoint);
                  Value dqw = b.create<arith::MulFOp>(loc, shifted, scale);
                  b.create<linalg::YieldOp>(loc, dqw);
                })

@@ -168,8 +176,8 @@ public:
    Value matmulDequant =
        rewriter
            .create<linalg::GenericOp>(
                loc, output.getType(), ValueRange{lhsExpanded, rhsDequant},
                output,
                /*indexingMaps=*/matIndexingMaps,
                /*iteratorTypes=*/matmulIteratorTypes,
                [&](OpBuilder &b, Location loc, ValueRange args) {

@@ -188,7 +196,8 @@ public:

namespace {
class ConvertCustomQuantOpPass
    : public TorchConversion::ConvertCustomQuantOpBase<
          ConvertCustomQuantOpPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<arith::ArithDialect>();
    registry.insert<func::FuncDialect>();

@@ -213,8 +222,8 @@ class ConvertCustomQuantOpPass
    target.addIllegalOp<OperatorOp>();
    patterns.add<ConvertCustomQuantizedMatmulOp>(typeConverter, context);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};

@@ -33,7 +33,6 @@ using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
using namespace TMTensor;

namespace {
class VerifyLinalgOnTensorsBackendContractPass
    : public VerifyLinalgOnTensorsBackendContractBase<

@@ -96,7 +95,8 @@ class VerifyLinalgOnTensorsBackendContractPass
    // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
    // doesn't unnecessarily spew out the entire module.
    emitError(module.getLoc())
        << "Module does not conform to the linalg-on-tensors backend "
           "contract. "
           "See dialect conversion legality information above.";
    return signalPassFailure();
  }

@@ -45,7 +45,8 @@ class VerifyStablehloBackendContractPass
    ConversionTarget target(*context);

    // Structural operations.
    target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
        opHasLegalTypes);
    // Shape operations.
    target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);

@@ -31,18 +31,18 @@ TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape)
  PRINT_FUNCTION();
}
TorchMlirBackendData::TorchMlirBackendData(
    BackendDevice device, Shape shape, std::shared_ptr<BackendData::Info> info)
    : BackendData(device, shape), info_(info) {
  PRINT_FUNCTION();
}
TorchMlirBackendData::TorchMlirBackendData(const at::Scalar &scalar,
                                           BackendDevice device)
    : BackendData(device, Shape(scalar.type(), {})),
      info_(std::make_shared<TorchMlirBackendData::Info>(scalar)) {
  PRINT_FUNCTION();
}
TorchMlirBackendData::TorchMlirBackendData(const at::Tensor &tensor,
                                           BackendDevice device, Shape shape)
    : BackendData(device, shape),
      info_(std::make_shared<TorchMlirBackendData::Info>(tensor)) {
  PRINT_FUNCTION();

@@ -52,19 +52,18 @@ BackendData::Handle TorchMlirBackendData::GetHandle() {
  return reinterpret_cast<int64_t>(this);
}

void TorchMlirBackendData::Assign(const BackendData &data) {
  const TorchMlirBackendData *torch_mlir_data =
      dynamic_cast<const TorchMlirBackendData *>(&data);
  TORCH_CHECK(torch_mlir_data,
              "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");

  info_ = torch_mlir_data->info_;
}

bool TorchMlirBackendData::HasValue() const { return bool(info_); }

BackendData::Info *TorchMlirBackendData::mlir_info() const {
  return info_.get();
}

@@ -77,8 +76,8 @@ void TorchMlirBackendImpl::PrepareToExit() const {}
 * IR Tracing
 * */

const IrBuilder *TorchMlirBackendImpl::GetIrBuilder() const {
  static const IrBuilder *builder = new TorchMlirIrBuilder();
  return builder;
}

@@ -87,28 +86,29 @@ const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const {
 * */

BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor(
    const at::Tensor &tensor, const Shape &shape,
    const BackendDevice &device) const {
  PRINT_FUNCTION();
  return std::make_shared<TorchMlirBackendData>(tensor, device, shape);
}

BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar(
    const at::Scalar &scalar, const BackendDevice &device) const {
  PRINT_FUNCTION();
  return std::make_shared<TorchMlirBackendData>(scalar, device);
}

BackendDataPtr
TorchMlirBackendImpl::CreateDataPlaceholder(const BackendDevice &device,
                                            const Shape &shape) const {
  PRINT_FUNCTION();
  return std::make_shared<TorchMlirBackendData>(device, shape);
}

BackendDataPtr
TorchMlirBackendImpl::GetComputationDataFromNode(const Node *node) const {
  PRINT_FUNCTION();
  const auto *device_data_node = dynamic_cast<const DeviceData *>(node);
  if (!device_data_node) {
    return nullptr;
  }

@@ -120,14 +120,13 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
    c10::optional<at::ScalarType> logical_scalar_type) const {
  PRINT_FUNCTION();

  TorchMlirBackendData *torch_mlir_data =
      dynamic_cast<TorchMlirBackendData *>(data.get());
  TORCH_CHECK(torch_mlir_data,
              "Invalid Backend Data Pointer. Expected TorchMlirBackendData.");

  TorchMlirBackendData::Info *info =
      dynamic_cast<TorchMlirBackendData::Info *>(torch_mlir_data->mlir_info());
  TORCH_CHECK(
      info,
      "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info.");

@@ -140,17 +139,19 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
 * */

std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
    const std::string &name, BackendDevice device,
    c10::ArrayRef<const Node *> post_order,
    Util::EmissionMap emit_status) const {
  PRINT_FUNCTION();
  return std::make_unique<TorchMlirLoweringContext>(
      name, std::forward<BackendDevice>(device),
      std::forward<c10::ArrayRef<const Node *>>(post_order),
      std::forward<Util::EmissionMap>(emit_status));
}

std::unique_ptr<LoweringContext>
TorchMlirBackendImpl::CreateLoweringContext(const std::string &name,
                                            BackendDevice device) const {
  PRINT_FUNCTION();
  return std::make_unique<TorchMlirLoweringContext>(
      name, std::forward<BackendDevice>(device));

@@ -175,9 +176,8 @@ at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const {
// Query all available backend devices
std::vector<BackendDevice> TorchMlirBackendImpl::GetBackendDevices() const {
  PRINT_FUNCTION();
  return {GetBackendDevice(c10::Device(c10::kLazy, 0)),
          GetBackendDevice(c10::Device(c10::kCPU, 0))};
}

// Map a particular c10:: device to a concrete backend device

@@ -41,27 +41,28 @@ public:
      name = ss.str();
      ++i;
    }
    Info(const Info &other)
        : tensor{other.tensor}, scalar{other.scalar},
          requires_grad{other.requires_grad}, name{other.name} {}
    Info(const at::Tensor &tensor)
        : tensor{tensor}, requires_grad{tensor.requires_grad()} {}
    Info(const at::Scalar &scalar) : scalar{scalar}, requires_grad(false) {}
  };

  TorchMlirBackendData(BackendDevice device, Shape shape);
  TorchMlirBackendData(BackendDevice device, Shape shape,
                       std::shared_ptr<BackendData::Info> info);
  TorchMlirBackendData(const at::Scalar &scalar, BackendDevice device);
  TorchMlirBackendData(const at::Tensor &tensor, BackendDevice device,
                       Shape shape);

  virtual BackendData::Handle GetHandle() override;

  virtual void Assign(const BackendData &data) override;

  virtual bool HasValue() const override;

  BackendData::Info *mlir_info() const;

protected:
  std::shared_ptr<BackendData::Info> info_;

@@ -80,7 +81,7 @@ public:
   * IR Tracing
   * */

  const IrBuilder *GetIrBuilder() const override;

  /**
   * Configuration

@@ -91,19 +92,22 @@ public:
   * Data Transfer
   * */

  virtual BackendDataPtr
  MakeComputationDataFromTensor(const at::Tensor &tensor, const Shape &shape,
                                const BackendDevice &device) const override;

  virtual BackendDataPtr
  MakeComputationDataFromScalar(const at::Scalar &scalar,
                                const BackendDevice &device) const override;

  virtual BackendDataPtr
  CreateDataPlaceholder(const BackendDevice &device,
                        const Shape &shape) const override;

  // Gets backend data if the node is a device data node. Otherwise returns
  // nullptr.
  virtual BackendDataPtr
  GetComputationDataFromNode(const Node *) const override;

  virtual at::Tensor MakeTensorFromComputationData(
      const BackendDataPtr data,

@@ -113,13 +117,14 @@ public:
   * Lowering, Compilation, Execution
   * */

  virtual std::unique_ptr<LoweringContext>
  CreateLoweringContext(const std::string &name, BackendDevice device,
                        c10::ArrayRef<const Node *> post_order,
                        Util::EmissionMap emit_status) const override;

  virtual std::unique_ptr<LoweringContext>
  CreateLoweringContext(const std::string &name,
                        BackendDevice device) const override;

  // TODO(whc) need to keep this?
  // virtual std::vector<std::string> GetCompilationDevices(

@@ -16,20 +16,18 @@ namespace torch {
namespace lazy {

DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed)
    : TorchMlirNode(op, operands, /*num_outputs=*/1,
                    /* hash_seed */ HashCombine(op.hash(), hash_seed)) {}

std::string DimensionNode::ToString() const { return "DimensionNode"; }

SizeNode::SizeNode(Value input, size_t dim)
    : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::size")}, {input},
                    MHash(dim)),
      dim_(dim){};

int64_t SizeNode::getStaticValue() const {
  return dynamic_cast<const TorchMlirNode *>(operand(0).node)
      ->shape(0)
      .size(dim_);
}

@@ -40,8 +38,9 @@ SizeAdd::SizeAdd(Value a, Value b)
    : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){};

int64_t SizeAdd::getStaticValue() const {
  return dynamic_cast<const DimensionNode *>(operand(0).node)
             ->getStaticValue() +
         dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue();
}

std::string SizeAdd::ToString() const { return "SizeAdd"; }

@@ -50,8 +49,9 @@ SizeMul::SizeMul(Value a, Value b)
    : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){};

int64_t SizeMul::getStaticValue() const {
  return dynamic_cast<const DimensionNode *>(operand(0).node)
             ->getStaticValue() *
         dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue();
}

std::string SizeMul::ToString() const { return "SizeMul"; }

@@ -61,11 +61,12 @@ SizeDiv::SizeDiv(Value a, Value b)

int64_t SizeDiv::getStaticValue() const {
  TORCH_CHECK(
      dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue() !=
          0,
      "Can't divide a dimension by zero");
  return dynamic_cast<const DimensionNode *>(operand(0).node)
             ->getStaticValue() /
         dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue();
}

std::string SizeDiv::ToString() const { return "SizeDiv"; }

@ -12,14 +12,14 @@
|
||||||
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
#include <torch/csrc/jit/api/compilation_unit.h>
|
|
||||||
#include <torch/csrc/jit/passes/refine_tuple_types.h>
|
|
||||||
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
|
||||||
#include <torch/csrc/lazy/core/config.h>
|
|
||||||
#include "torch-mlir-c/Registration.h"
|
|
||||||
#include "torch-mlir-c/Transforms.h"
|
|
||||||
#include "mlir-c/IR.h"
|
#include "mlir-c/IR.h"
|
||||||
#include "mlir-c/Pass.h"
|
#include "mlir-c/Pass.h"
|
||||||
|
#include "torch-mlir-c/Registration.h"
|
||||||
|
#include "torch-mlir-c/Transforms.h"
|
||||||
|
#include <torch/csrc/jit/api/compilation_unit.h>
|
||||||
|
#include <torch/csrc/jit/passes/refine_tuple_types.h>
|
||||||
|
#include <torch/csrc/lazy/core/config.h>
|
||||||
|
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
|
||||||
|
|
||||||
#include "backend_impl.h"
|
#include "backend_impl.h"
|
||||||
#include "jit_ir_importer/function_importer.h"
|
#include "jit_ir_importer/function_importer.h"
|
||||||
|
@ -38,8 +38,8 @@ namespace lazy {
|
||||||
// TorchMlir Lowering Context
|
// TorchMlir Lowering Context
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
TorchMlirLoweringContext::TorchMlirLoweringContext(
|
TorchMlirLoweringContext::TorchMlirLoweringContext(const std::string &name,
|
||||||
const std::string& name, BackendDevice device)
|
BackendDevice device)
|
||||||
: LoweringContext(name, std::forward<BackendDevice>(device)),
|
: LoweringContext(name, std::forward<BackendDevice>(device)),
|
||||||
graph_(std::make_shared<torch::jit::Graph>()),
|
graph_(std::make_shared<torch::jit::Graph>()),
|
||||||
function_(
|
function_(
|
||||||
|
@ -49,11 +49,12 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
|
||||||
}
|
}
|
||||||
|
|
||||||
TorchMlirLoweringContext::TorchMlirLoweringContext(
|
TorchMlirLoweringContext::TorchMlirLoweringContext(
|
||||||
const std::string& name, BackendDevice device,
|
const std::string &name, BackendDevice device,
|
||||||
c10::ArrayRef<const torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
|
c10::ArrayRef<const torch::lazy::Node *> post_order,
|
||||||
|
Util::EmissionMap emit_status)
|
||||||
: LoweringContext(
|
: LoweringContext(
|
||||||
name, std::forward<BackendDevice>(device),
|
name, std::forward<BackendDevice>(device),
|
||||||
std::forward<c10::ArrayRef<const torch::lazy::Node*>>(post_order),
|
std::forward<c10::ArrayRef<const torch::lazy::Node *>>(post_order),
|
||||||
std::forward<Util::EmissionMap>(emit_status)),
|
std::forward<Util::EmissionMap>(emit_status)),
|
||||||
graph_(std::make_shared<torch::jit::Graph>()),
|
graph_(std::make_shared<torch::jit::Graph>()),
|
||||||
function_(
|
function_(
|
||||||
|
@ -66,9 +67,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TorchMlirLoweringContext::Lower(const Node* node) {
|
void TorchMlirLoweringContext::Lower(const Node *node) {
|
||||||
if (auto* torch_mlir_node =
|
if (auto *torch_mlir_node =
|
||||||
dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
|
dynamic_cast<const torch::lazy::TorchMlirNode *>(node)) {
|
||||||
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this);
|
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this);
|
||||||
CHECK(!ops.empty()) << "Failed to lower: " << *node;
|
CHECK(!ops.empty()) << "Failed to lower: " << *node;
|
||||||
TORCH_CHECK_EQ(node->num_outputs(), ops.size());
|
TORCH_CHECK_EQ(node->num_outputs(), ops.size());
|
||||||
|
@ -82,19 +83,19 @@ void TorchMlirLoweringContext::Lower(const Node* node) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TorchMlirLoweringContext::SetUpAlias(
|
void TorchMlirLoweringContext::SetUpAlias(
|
||||||
const std::vector<int64_t>& output_index, int64_t param_number,
|
const std::vector<int64_t> &output_index, int64_t param_number,
|
||||||
const std::vector<int64_t>& param_index, bool must_alias) {
|
const std::vector<int64_t> ¶m_index, bool must_alias) {
|
||||||
input_output_aliases_.push_back(
|
input_output_aliases_.push_back(
|
||||||
{output_index, param_number, param_index, must_alias});
|
{output_index, param_number, param_index, must_alias});
|
||||||
}
|
}
|
||||||
|
|
||||||
bool TorchMlirLoweringContext::CheckResultShape(
|
bool TorchMlirLoweringContext::CheckResultShape(
|
||||||
const BackendDataPtr& parameter_data, size_t result_idx) {
|
const BackendDataPtr ¶meter_data, size_t result_idx) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(result_idx < root_tuple_.size(),
|
||||||
result_idx < root_tuple_.size(), "Tried getting result shape at index ",
|
"Tried getting result shape at index ", result_idx,
|
||||||
result_idx, " which is out of bounds!");
|
" which is out of bounds!");
|
||||||
|
|
||||||
torch::jit::Value* output = root_tuple_[result_idx];
|
torch::jit::Value *output = root_tuple_[result_idx];
|
||||||
|
|
||||||
if (c10::TensorTypePtr tensor_type =
|
if (c10::TensorTypePtr tensor_type =
|
||||||
output->type()->cast<c10::TensorType>()) {
|
output->type()->cast<c10::TensorType>()) {
|
||||||
|
@ -111,7 +112,7 @@ bool TorchMlirLoweringContext::CheckResultShape(
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t TorchMlirLoweringContext::AddResult(const Output& output) {
|
size_t TorchMlirLoweringContext::AddResult(const Output &output) {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
|
|
||||||
return AddResult(GetOutputOp(output));
|
return AddResult(GetOutputOp(output));
|
||||||
|
@ -120,9 +121,10 @@ size_t TorchMlirLoweringContext::AddResult(const Output& output) {
|
||||||
// Associates the given output with the input parameter of the given index and
|
// Associates the given output with the input parameter of the given index and
|
||||||
// shape. Only used for the operator-by-operator execution, mostly for
|
// shape. Only used for the operator-by-operator execution, mostly for
|
||||||
// debugging purposes.
|
// debugging purposes.
|
||||||
void TorchMlirLoweringContext::AddParameter(
|
void TorchMlirLoweringContext::AddParameter(const torch::lazy::Output &output,
|
||||||
const torch::lazy::Output& output, size_t index,
|
size_t index,
|
||||||
const torch::lazy::Shape& shape, const std::string& name) {
|
const torch::lazy::Shape &shape,
|
||||||
|
const std::string &name) {
|
||||||
UNIMPLEMENTED_FUNCTION_ERROR();
|
UNIMPLEMENTED_FUNCTION_ERROR();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -136,7 +138,7 @@ ComputationPtr TorchMlirLoweringContext::Build() {
|
||||||
torch::jit::RefineTupleTypes(graph_);
|
torch::jit::RefineTupleTypes(graph_);
|
||||||
|
|
||||||
// Insert return values into graph.
|
// Insert return values into graph.
|
||||||
for (torch::jit::Value* output : root_tuple_) {
|
for (torch::jit::Value *output : root_tuple_) {
|
||||||
graph_->block()->registerOutput(output);
|
graph_->block()->registerOutput(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,7 +154,6 @@ ComputationPtr TorchMlirLoweringContext::Build() {
|
||||||
/*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
|
/*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
|
||||||
/*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
|
/*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
|
||||||
|
|
||||||
|
|
||||||
// Convert MlirOperation to MlirModule.
|
// Convert MlirOperation to MlirModule.
|
||||||
MlirLocation loc = mlirLocationUnknownGet(mlir_context_);
|
MlirLocation loc = mlirLocationUnknownGet(mlir_context_);
|
||||||
MlirModule module_op = mlirModuleCreateEmpty(loc);
|
MlirModule module_op = mlirModuleCreateEmpty(loc);
|
||||||
|
@ -162,14 +163,10 @@ ComputationPtr TorchMlirLoweringContext::Build() {
|
||||||
// Apply passes to verify generated MLIR.
|
// Apply passes to verify generated MLIR.
|
||||||
auto pass_manager = mlirPassManagerCreate(mlir_context_);
|
auto pass_manager = mlirPassManagerCreate(mlir_context_);
|
||||||
mlirPassManagerAddOwnedPass(
|
mlirPassManagerAddOwnedPass(
|
||||||
pass_manager,
|
pass_manager, mlirCreateVerifyBackendContractNoDecompositions());
|
||||||
mlirCreateVerifyBackendContractNoDecompositions()
|
|
||||||
);
|
|
||||||
|
|
||||||
MlirLogicalResult result = mlirPassManagerRunOnOp(
|
MlirLogicalResult result =
|
||||||
pass_manager,
|
mlirPassManagerRunOnOp(pass_manager, mlirModuleGetOperation(module_op));
|
||||||
mlirModuleGetOperation(module_op)
|
|
||||||
);
|
|
||||||
|
|
||||||
if (mlirLogicalResultIsFailure(result)) {
|
if (mlirLogicalResultIsFailure(result)) {
|
||||||
throw std::runtime_error("MLIR verification has failed.");
|
throw std::runtime_error("MLIR verification has failed.");
|
||||||
|
@ -178,12 +175,14 @@ ComputationPtr TorchMlirLoweringContext::Build() {
|
||||||
return CreateComputation(module_op);
|
return CreateComputation(module_op);
|
||||||
}
|
}
|
||||||
|
|
||||||
ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirModule module_op) {
|
ComputationPtr
|
||||||
return std::make_shared<TorchMlirComputation>(
|
TorchMlirLoweringContext::CreateComputation(MlirModule module_op) {
|
||||||
module_op, mlir_context_, graph_, parameter_names_, input_output_aliases_);
|
return std::make_shared<TorchMlirComputation>(module_op, mlir_context_,
|
||||||
|
graph_, parameter_names_,
|
||||||
|
input_output_aliases_);
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
|
torch::jit::Value *TorchMlirLoweringContext::GetOutputOp(const Output &output) {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
|
|
||||||
auto it = emitted_outputs_.find(output);
|
auto it = emitted_outputs_.find(output);
|
||||||
|
@ -195,15 +194,14 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
|
||||||
// At this point the output better be present, otherwise there is an issue
|
// At this point the output better be present, otherwise there is an issue
|
||||||
// with the lowering code.
|
// with the lowering code.
|
||||||
it = emitted_outputs_.find(output);
|
it = emitted_outputs_.find(output);
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(it != emitted_outputs_.end(),
|
||||||
it != emitted_outputs_.end(),
|
"No MLIR operation emitted for output: ", output.ToString());
|
||||||
"No MLIR operation emitted for output: ", output.ToString());
|
|
||||||
}
|
}
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void TorchMlirLoweringContext::AssignOutputOp(
|
void TorchMlirLoweringContext::AssignOutputOp(const Output &output,
|
||||||
const Output& output, torch::jit::Value* op) {
|
torch::jit::Value *op) {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
|
|
||||||
auto torch_mlir_node =
|
auto torch_mlir_node =
|
||||||
|
@ -211,48 +209,44 @@ void TorchMlirLoweringContext::AssignOutputOp(
|
||||||
|
|
||||||
std::vector<std::string> source_files, functions;
|
std::vector<std::string> source_files, functions;
|
||||||
std::vector<int64_t> line_numbers;
|
std::vector<int64_t> line_numbers;
|
||||||
const auto& metadata = torch_mlir_node->metadata();
|
const auto &metadata = torch_mlir_node->metadata();
|
||||||
const auto& frames = metadata.frame_info;
|
const auto &frames = metadata.frame_info;
|
||||||
if (!frames.empty()) {
|
if (!frames.empty()) {
|
||||||
static std::vector<std::string> g_roots =
|
static std::vector<std::string> g_roots =
|
||||||
string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":");
|
string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":");
|
||||||
|
|
||||||
std::for_each(frames.rbegin(), frames.rend(),
|
std::for_each(frames.rbegin(), frames.rend(),
|
||||||
[&](const torch::lazy::SourceLocation& location) {
|
[&](const torch::lazy::SourceLocation &location) {
|
||||||
functions.push_back(location.function);
|
functions.push_back(location.function);
|
||||||
line_numbers.push_back(location.line);
|
line_numbers.push_back(location.line);
|
||||||
|
|
||||||
std::string file_name = location.file;
|
std::string file_name = location.file;
|
||||||
for (const std::string& root : g_roots) {
|
for (const std::string &root : g_roots) {
|
||||||
if (startswith(file_name, root)) {
|
if (startswith(file_name, root)) {
|
||||||
// location.file starts with root, strip it off
|
// location.file starts with root, strip it off
|
||||||
file_name = file_name.substr(root.size());
|
file_name = file_name.substr(root.size());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
source_files.push_back(file_name);
|
source_files.push_back(file_name);
|
||||||
});
|
});
|
||||||
|
|
||||||
if (!source_files.empty()) {
|
if (!source_files.empty()) {
|
||||||
op->node()->ss_(
|
op->node()->ss_(c10::Symbol::attr("source_files"), source_files);
|
||||||
c10::Symbol::attr("source_files"), source_files);
|
op->node()->ss_(c10::Symbol::attr("functions"), functions);
|
||||||
op->node()->ss_(
|
op->node()->is_(c10::Symbol::attr("line_numbers"), line_numbers);
|
||||||
c10::Symbol::attr("functions"), functions);
|
|
||||||
op->node()->is_(
|
|
||||||
c10::Symbol::attr("line_numbers"), line_numbers);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto scope = ::c10::Symbol::scope(metadata.scope);
|
auto scope = ::c10::Symbol::scope(metadata.scope);
|
||||||
op->node()->setScope(
|
op->node()->setScope(c10::make_intrusive<torch::jit::Scope>()->push(scope));
|
||||||
c10::make_intrusive<torch::jit::Scope>()->push(scope));
|
|
||||||
|
|
||||||
emitted_outputs_[output] = std::move(op);
|
emitted_outputs_[output] = std::move(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
|
torch::jit::Value *TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
|
|
||||||
if (!dynamic_cast<TorchMlirBackendData*>(data.get())) {
|
if (!dynamic_cast<TorchMlirBackendData *>(data.get())) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
false,
|
false,
|
||||||
"Expected TorchMlirBackendData. Got some other BackendData type");
|
"Expected TorchMlirBackendData. Got some other BackendData type");
|
||||||
|
@ -263,20 +257,21 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
|
||||||
auto it = parameters_map_.find(handle);
|
auto it = parameters_map_.find(handle);
|
||||||
|
|
||||||
if (it == parameters_map_.end()) {
|
if (it == parameters_map_.end()) {
|
||||||
torch::jit::Value* param =
|
torch::jit::Value *param =
|
||||||
graph_->addInput(c10::str("p", parameters_.size()));
|
graph_->addInput(c10::str("p", parameters_.size()));
|
||||||
|
|
||||||
auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
|
auto *info =
|
||||||
|
dynamic_cast<TorchMlirBackendData::Info *>(mlir_data->mlir_info());
|
||||||
TORCH_CHECK(info, "Expected TorchMlirBackendData::Info");
|
TORCH_CHECK(info, "Expected TorchMlirBackendData::Info");
|
||||||
if (info->scalar.has_value()) {
|
if (info->scalar.has_value()) {
|
||||||
auto& scalar = info->scalar.value();
|
auto &scalar = info->scalar.value();
|
||||||
if (scalar.isFloatingPoint()) {
|
if (scalar.isFloatingPoint()) {
|
||||||
param->setType(c10::FloatType::get());
|
param->setType(c10::FloatType::get());
|
||||||
} else if (scalar.isIntegral(true)) {
|
} else if (scalar.isIntegral(true)) {
|
||||||
param->setType(c10::IntType::get());
|
param->setType(c10::IntType::get());
|
||||||
} else {
|
} else {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(false,
|
||||||
false, "Unhandled scalar type: ", c10::toString(scalar.type()));
|
"Unhandled scalar type: ", c10::toString(scalar.type()));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Save parameter shape information.
|
// Save parameter shape information.
|
||||||
|
@ -305,7 +300,7 @@ std::shared_ptr<torch::jit::Graph> TorchMlirLoweringContext::graph() const {
|
||||||
return graph_;
|
return graph_;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
|
size_t TorchMlirLoweringContext::AddResult(torch::jit::Value *op) {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
root_tuple_.push_back(std::move(op));
|
root_tuple_.push_back(std::move(op));
|
||||||
return root_tuple_.size() - 1;
|
return root_tuple_.size() - 1;
|
||||||
|
@ -313,9 +308,9 @@ size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
|
||||||
|
|
||||||
// Sync vector of c10::Argument with type specified from parallel list of
|
// Sync vector of c10::Argument with type specified from parallel list of
|
||||||
// jit::Value. There must be a 1:1 map between elements of args and values.
|
// jit::Value. There must be a 1:1 map between elements of args and values.
|
||||||
std::vector<c10::Argument> sync_argument_types(
|
std::vector<c10::Argument>
|
||||||
const std::vector<c10::Argument>& args,
|
sync_argument_types(const std::vector<c10::Argument> &args,
|
||||||
c10::ArrayRef<torch::jit::Value*> values) {
|
c10::ArrayRef<torch::jit::Value *> values) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
args.size() == values.size(),
|
args.size() == values.size(),
|
||||||
"Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ",
|
"Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ",
|
||||||
|
@ -362,7 +357,7 @@ void TorchMlirLoweringContext::RegisterMlirDialects() {
|
||||||
|
|
||||||
TorchMlirComputation::TorchMlirComputation(
|
TorchMlirComputation::TorchMlirComputation(
|
||||||
MlirModule module_op, MlirContext mlir_context,
|
MlirModule module_op, MlirContext mlir_context,
|
||||||
const std::shared_ptr<torch::jit::Graph>& graph,
|
const std::shared_ptr<torch::jit::Graph> &graph,
|
||||||
std::unordered_map<int, std::string> parameters_map,
|
std::unordered_map<int, std::string> parameters_map,
|
||||||
InputOutputAliases input_output_aliases)
|
InputOutputAliases input_output_aliases)
|
||||||
: module_op_(std::move(module_op)), mlir_context_(std::move(mlir_context)),
|
: module_op_(std::move(module_op)), mlir_context_(std::move(mlir_context)),
|
||||||
|
@ -377,26 +372,25 @@ TorchMlirComputation::TorchMlirComputation(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int TorchMlirComputation::parameters_size() const {
|
int TorchMlirComputation::parameters_size() const { return num_parameters_; }
|
||||||
return num_parameters_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<torch::lazy::Shape>&
|
const std::vector<torch::lazy::Shape> &
|
||||||
TorchMlirComputation::parameter_shapes() const {
|
TorchMlirComputation::parameter_shapes() const {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"todo(whc) implement ts computation shapes or change interface");
|
"todo(whc) implement ts computation shapes or change interface");
|
||||||
return parameter_shapes_;
|
return parameter_shapes_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
|
const std::vector<std::string> &TorchMlirComputation::parameter_names() const {
|
||||||
return parameter_names_;
|
return parameter_names_;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::unordered_map<int, std::string> &
TorchMlirComputation::parameters_map() const {
  return parameters_map_;
}

const torch::lazy::Shape &TorchMlirComputation::result_shape() const {
  throw std::runtime_error(
      "todo(whc) implement ts computation shapes or change interface");
  return result_shape_;

@@ -411,13 +405,9 @@ MlirOperation TorchMlirComputation::func_op() const {
  return mlirBlockGetFirstOperation(block);
}

MlirModule TorchMlirComputation::module_op() const { return module_op_; }

MlirContext TorchMlirComputation::mlir_context() const { return mlir_context_; }

const std::string TorchMlirComputation::debug_string() const {
  std::stringstream ss;

@@ -430,7 +420,7 @@ const std::string TorchMlirComputation::debug_string() const {

  // Parameter names
  ss << "Parameter names:\n";
  for (auto &p : parameter_names_) {
    ss << " " << p << "\n";
  }
  ss << "\n";

@@ -451,10 +441,10 @@ const std::string TorchMlirComputation::debug_string() const {

const std::string TorchMlirComputation::to_string() const {
  // Since we use the C-MLIR API, we need to use a callback to print.
  MlirStringCallback print_callback = [](MlirStringRef part, void *user_data) {
    // user_data is a void ptr to some data structure of our choice -- in this
    // case, the string stream where we'll be accumulating the strings.
    std::stringstream *ss_ptr = static_cast<std::stringstream *>(user_data);
    *ss_ptr << std::string(part.data, part.length);
  };
  std::stringstream ss;

@@ -462,7 +452,8 @@ const std::string TorchMlirComputation::to_string() const {
  // Setup flags for MLIR serialization.
  MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
  mlirOpPrintingFlagsEnableDebugInfo(flags, FLAGS_torch_lazy_ir_debug, false);
  mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags,
                              print_callback, &ss);
  return ss.str();
}
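For readers unfamiliar with the callback-based printing idiom that `to_string()` above relies on, here is a small illustrative sketch, not part of the commit; it assumes only the `mlir-c` headers and an already-constructed `MlirOperation`:

```
// Illustrative only: the MLIR C API streams text through a callback instead of
// returning a std::string, so callers accumulate the pieces themselves.
#include "mlir-c/IR.h"
#include <sstream>
#include <string>

static std::string printOperation(MlirOperation op) {
  std::stringstream ss;
  MlirStringCallback callback = [](MlirStringRef part, void *user_data) {
    auto *out = static_cast<std::stringstream *>(user_data);
    out->write(part.data, static_cast<std::streamsize>(part.length));
  };
  // Streams the textual IR of `op` into `ss`, piece by piece.
  mlirOperationPrint(op, callback, &ss);
  return ss.str();
}
```

The same pattern works with `mlirOperationPrintWithFlags` when printing flags such as debug info are needed, which is what the code above does.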
@@ -39,35 +39,34 @@ public:
  };
  using InputOutputAliases = std::vector<InputOutputAlias>;

  TorchMlirLoweringContext(const std::string &name,
                           torch::lazy::BackendDevice device);
  TorchMlirLoweringContext(const std::string &name,
                           torch::lazy::BackendDevice device,
                           c10::ArrayRef<const torch::lazy::Node *> post_order,
                           torch::lazy::Util::EmissionMap emit_status);

  void Lower(const Node *node);

  // Adds a new input/output alias.
  void SetUpAlias(const std::vector<int64_t> &output_index,
                  int64_t param_number, const std::vector<int64_t> &param_index,
                  bool must_alias = false) override;

  // Check if parameter shape matches result at index.
  bool CheckResultShape(const BackendDataPtr &parameter_data,
                        size_t result_idx) override;

  // Adds the given output as a component of the result tuple and returns its
  // assigned position within the tuple.
  size_t AddResult(const torch::lazy::Output &output) override;

  // Associates the given output with the input parameter of the given index and
  // shape. Only used for the operator-by-operator execution, mostly for
  // debugging purposes.
  void AddParameter(const torch::lazy::Output &output, size_t index,
                    const torch::lazy::Shape &shape,
                    const std::string &name) override;

  // Build the computation capturing all the operations created with the
  // embedded builder (returned by the builder() API).

@@ -78,27 +77,27 @@ public:
  // Retrieves the lowered operation for an output. If the requested output is
  // not available yet, the graph behind the output's Node is lowered, and the
  // corresponding TS operation returned.
  torch::jit::Value *GetOutputOp(const Output &output);

  // Assigns the given TS operation to the specified output. As outputs are
  // lowered in a post-order fashion, later nodes should always find their
  // operands among the emitted outputs.
  void AssignOutputOp(const Output &output, torch::jit::Value *op);

  // If a parameter associated with data has already been declared, it will be
  // returned. Otherwise a new one will be created, associated with the tensor
  // held in data.
  torch::jit::Value *GetParameter(BackendDataPtr data);

  std::shared_ptr<torch::jit::Graph> graph() const;

protected:
  struct Parameter {
    torch::jit::Value *param;
    size_t index = 0;
  };

  size_t AddResult(torch::jit::Value *op);

  // Creates a jit::Function from the current jit::Graph. Input and output
  // type information is patched to include shape.

@@ -113,8 +112,8 @@ protected:
  MlirContext mlir_context_;
  std::unordered_map<BackendData::Handle, Parameter> parameters_map_;
  std::unordered_map<int, std::string> parameter_names_;
  std::vector<torch::jit::Value *> root_tuple_;
  OutputMap<torch::jit::Value *> emitted_outputs_;
};

class TORCH_API TorchMlirComputation : public torch::lazy::Computation {

@@ -122,21 +121,20 @@ public:
  using InputOutputAliases = TorchMlirLoweringContext::InputOutputAliases;
  using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias;

  TorchMlirComputation(MlirModule module_op, MlirContext mlir_context,
                       const std::shared_ptr<torch::jit::Graph> &graph,
                       std::unordered_map<int, std::string> parameters_map,
                       InputOutputAliases input_output_aliases);

  int parameters_size() const override;

  const std::vector<torch::lazy::Shape> &parameter_shapes() const override;

  const std::vector<std::string> &parameter_names() const override;

  const std::unordered_map<int, std::string> &parameters_map() const;

  const torch::lazy::Shape &result_shape() const override;

  std::shared_ptr<torch::jit::Graph> graph() const;
@@ -10,8 +10,8 @@
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
//===----------------------------------------------------------------------===//

#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/MetaFunctions.h>

@@ -33,16 +33,16 @@
#include "generated/LazyIr.h"
#include "generated/LazyNativeFunctions.h"
#include "generated/shape_inference.h"
#include "ops/index.h"
#include "ops/ivalue.h"
#include "ops/split.h"
#include "ops/to_copy.h"
#include "ops/unbind_int.h"
#include "utils/exception.h"
#include "utils/sys_utils.h"

namespace {
at::Tensor to_meta(const at::Tensor &tensor) {
  // undefined tensors can't be converted to the meta device, since they don't
  // have sizes/strides
  if (!tensor.defined())

@@ -60,7 +60,7 @@ at::Tensor to_meta(const at::Tensor& tensor) {
  return out;
}

c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor> &tensor) {
  if (tensor.has_value()) {
    return to_meta(*tensor);
  }

@@ -70,16 +70,17 @@ c10::optional<at::Tensor> to_meta(const c10::optional<at::Tensor>& tensor) {
std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
  std::vector<at::Tensor> outs;
  outs.reserve(t_list.size());
  for (const auto &tensor : t_list) {
    outs.push_back(to_meta(tensor));
  }
  return outs;
}

c10::List<c10::optional<at::Tensor>>
to_meta(const c10::List<c10::optional<at::Tensor>> &t_list) {
  c10::List<c10::optional<at::Tensor>> outs;
  outs.reserve(t_list.size());
  for (const auto &tensor : t_list) {
    outs.push_back(to_meta(tensor));
  }
  return outs;
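As context for the `to_meta()` helpers in the hunk above, this is a standalone sketch (not from the repo) of why the meta device is convenient for shape inference; it assumes nothing beyond a stock libtorch build:

```
// Illustrative only: meta tensors carry sizes/strides/dtype but no storage, so
// running an op on them computes output metadata without doing any real work.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor a = at::empty({4, 8}, at::TensorOptions().device(c10::kMeta));
  at::Tensor b = at::empty({8, 2}, at::TensorOptions().device(c10::kMeta));
  at::Tensor out = at::matmul(a, b); // only shape/dtype propagation happens
  std::cout << out.size(0) << "x" << out.size(1) << " on " << out.device()
            << "\n"; // prints "4x2 on meta"
  return 0;
}
```

The ops below follow the same idea: run the composite kernel on meta copies of the inputs, then read the result's scalar type and sizes to build `torch::lazy::Shape` objects.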
@@ -91,9 +92,9 @@ namespace lazy {

namespace {

at::Tensor
CreateLtcTensor(const at::Tensor &tensor,
                const c10::optional<torch::lazy::BackendDevice> &device) {
  if (tensor.defined() && device) {
    return torch::lazy::CreateAtenFromLtcTensor(
        torch::lazy::LazyTensor::Create(tensor, *device));

@@ -102,7 +103,7 @@ at::Tensor CreateLtcTensor(
}

c10::optional<torch::lazy::BackendDevice>
GetLtcDevice(const c10::optional<c10::Device> &device) {
  if (!device) {
    return c10::nullopt;
  }

@@ -112,24 +113,23 @@ GetLtcDevice(const c10::optional<c10::Device>& device) {
  return torch::lazy::atenDeviceToBackendDevice(*device);
}

torch::lazy::Value MaybeExpand(const torch::lazy::Value &input,
                               const torch::lazy::Shape &target_shape) {
  if (input.shape().sizes() == target_shape.sizes()) {
    return input;
  }
  return torch::lazy::MakeExpand(input, target_shape.sizes().vec(),
                                 /*is_scalar_expand=*/false);
}

void copy_(torch::lazy::LazyTensorPtr &input, torch::lazy::LazyTensorPtr &src) {
  if (input->GetDevice() == src->GetDevice()) {
    torch::lazy::Value copy_value;
    if (input->dtype() == src->dtype()) {
      copy_value = src->GetIrValue();
    } else {
      copy_value = torch::lazy::MakeCast(src->GetIrValue(), input->dtype(),
                                         src->dtype());
    }
    input->SetIrValue(MaybeExpand(copy_value, input->shape()));
  } else {

@@ -146,15 +146,17 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {

// clone is special in LT because we make it a no-op.
// This should be safe to do, because every operator in the LT is functional.
at::Tensor
LazyNativeFunctions::clone(const at::Tensor &self,
                           c10::optional<at::MemoryFormat> memory_format) {
  auto self_lt = torch::lazy::TryGetLtcTensor(self);
  return torch::lazy::CreateAtenFromLtcTensor(
      self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice()));
}

at::Tensor LazyNativeFunctions::_copy_from(const at::Tensor &self,
                                           const at::Tensor &dst,
                                           bool non_blocking) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);
@@ -199,16 +201,16 @@ at::Tensor LazyNativeFunctions::_copy_from(
      }
    } else {
      copy_(dst_tensor, self_tensor);
      auto *impl =
          dynamic_cast<torch::lazy::LTCTensorImpl *>(dst.unsafeGetTensorImpl());
      impl->set_tensor(dst_tensor);
    }
  }
  return dst;
}

at::Tensor LazyNativeFunctions::_copy_from_and_resize(const at::Tensor &self,
                                                      const at::Tensor &dst) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
  auto self_tensor = torch::lazy::TryGetLtcTensor(self);

@@ -223,8 +225,8 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize(
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // at this point we know dst is a lazy tensor
    auto *dest_impl =
        dynamic_cast<torch::lazy::LTCTensorImpl *>(dst.unsafeGetTensorImpl());
    dest_impl->tensor()->UpdateFromTensorOut(self_tensor);
    dest_impl->force_refresh_sizes();
  }

@@ -232,15 +234,16 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize(
}

at::Tensor LazyNativeFunctions::_to_copy(
    const at::Tensor &self, c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
    c10::optional<bool> pin_memory, bool non_blocking,
    c10::optional<at::MemoryFormat> memory_format) {
  PRINT_FUNCTION();
  auto options = self.options();
  if (dtype) {
    // I put each of these setters in a conditional instead of doing
    // `self.options().dtype(dtype).layout(layout)... because calling
    // .dtype(nullopt) on an options() that already has dtype appears to wipe it
    options = options.dtype(dtype);
  }
  if (layout) {
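The comment in the hunk above explains why each option is set inside its own conditional; a tiny standalone check of that `TensorOptions` behavior (illustrative only, assuming a stock libtorch build) looks like this:

```
// Illustrative only: re-setting a field of TensorOptions with nullopt clears
// it, which is why _to_copy() guards each .dtype()/.layout()/... call.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::TensorOptions opts = at::TensorOptions().dtype(at::kFloat);
  std::cout << opts.has_dtype() << "\n"; // 1 - dtype is set
  at::TensorOptions wiped = opts.dtype(c10::nullopt);
  std::cout << wiped.has_dtype() << "\n"; // 0 - dtype cleared, per the comment
  return 0;
}
```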
@@ -261,8 +264,9 @@ at::Tensor LazyNativeFunctions::_to_copy(
  if (!lazy_self && device && device->type() == c10::kLazy) {
    // Case 1: eager->lazy (we create a new lazy tensor)
    // See Note [Lazy Tensor Functionalization]
    // Invariant: if the functionalization key is in the exclude set, then we're
    // expected to return an ordinary tensor, which will be "lifted" into a
    // functional wrapper later.
    bool functionalize_output =
        !c10::impl::tls_local_dispatch_key_set().excluded_.has(
            c10::DispatchKey::Functionalize);

@@ -270,7 +274,8 @@ at::Tensor LazyNativeFunctions::_to_copy(
        self, options, *device, /*non_blocking=*/non_blocking,
        /*functionalize_output=*/functionalize_output);
  } else if (device && device->type() != c10::kLazy) {
    // Case 2: lazy->eager (forces a graph break since we are materializing a
    // tensor)

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);

@@ -278,22 +283,24 @@ at::Tensor LazyNativeFunctions::_to_copy(
    auto moved_eager_tensor =
        eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
    return moved_eager_tensor;
  } else if (device && device->type() == c10::kLazy && device->has_index() &&
             device->index() != self.device().index()) {

    // Case 3: lazy:0 -> lazy:1

    // TODO(whc) what do we actually want to do here?
    // option 1: materialize, move eager tensor, create new lazy tensor
    // - this should be our default, as it is what would happen before we
    // implemented _to_copy
    // - actually combines case 1 + case 2
    // option 2: support multiple devices inside one lazy/TS executor (case 4)
    // - but: we may have other assumptions that there is just one device
    // per executor? so don't take this lightly

    TORCH_INTERNAL_ASSERT(lazy_self);
    auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
    // we move the eager tensor to the 'eager' equivalent of our lazy device
    // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is
    // what we use
    auto eager_device = c10::Device(
        torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index());
    options = options.device(eager_device);

@@ -305,12 +312,14 @@ at::Tensor LazyNativeFunctions::_to_copy(
    return torch::lazy::CreateAtenFromLtcTensor(lazy_self);

  } else {
    // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy
    // graph)

    // Note: captured _to_copy will be executed with real eager tensors, not
    // lazy tensors. We DO NOT want to burn 'lazy:0' as the device into this
    // captured IR, or we will try to convert an eager tensor back to a lazy one
    // inside the torchscript executor lazy:0 -> lazy:1 is handled in case3, so
    // we can safely drop the device argument
    device = c10::nullopt;

    auto shapes = torch::lazy::compute_shape__to_copy(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::_unsafe_view(
|
at::Tensor LazyNativeFunctions::_unsafe_view(const at::Tensor &self,
|
||||||
const at::Tensor& self, at::IntArrayRef size) {
|
at::IntArrayRef size) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size));
|
return LazyNativeFunctions::view_copy_symint(self,
|
||||||
|
c10::fromIntArrayRefSlow(size));
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
|
at::Tensor LazyNativeFunctions::t(const at::Tensor &self) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
return at::functionalization::functionalize_aten_op<ATEN_OP(t)>::call(self);
|
return at::functionalization::functionalize_aten_op<ATEN_OP(t)>::call(self);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<at::Tensor> LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) {
|
std::vector<at::Tensor> LazyNativeFunctions::unbind_copy(const at::Tensor &self,
|
||||||
|
int64_t dim) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
auto common_device = torch::lazy::GetBackendDevice(self);
|
auto common_device = torch::lazy::GetBackendDevice(self);
|
||||||
TORCH_INTERNAL_ASSERT(common_device);
|
TORCH_INTERNAL_ASSERT(common_device);
|
||||||
|
|
||||||
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
LazyTensorPtr lazy_self =
|
||||||
torch::lazy::NodePtr node = torch::lazy::ReuseNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim);
|
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
||||||
|
torch::lazy::NodePtr node =
|
||||||
|
torch::lazy::ReuseNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim);
|
||||||
if (!node) {
|
if (!node) {
|
||||||
auto self_meta = to_meta(self);
|
auto self_meta = to_meta(self);
|
||||||
auto out_meta = at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim);
|
auto out_meta =
|
||||||
|
at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim);
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> shapes;
|
std::vector<torch::lazy::Shape> shapes;
|
||||||
for (const auto & shape : out_meta) {
|
for (const auto &shape : out_meta) {
|
||||||
shapes.push_back(
|
shapes.push_back(
|
||||||
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())
|
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()));
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if(torch::lazy::symbolicShapeEnabled()){
|
if (torch::lazy::symbolicShapeEnabled()) {
|
||||||
std::vector<torch::jit::IValue> inputs = { self, dim };
|
std::vector<torch::jit::IValue> inputs = {self, dim};
|
||||||
const char* schema_str = "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]";
|
const char *schema_str =
|
||||||
|
"aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]";
|
||||||
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
node = torch::lazy::MakeNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim, std::move(shapes));
|
node = torch::lazy::MakeNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim,
|
||||||
CacheNode(node);
|
std::move(shapes));
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<at::Tensor> result;
|
|
||||||
for (size_t i = 0; i < node->num_outputs(); ++i) {
|
|
||||||
result.push_back(
|
|
||||||
torch::lazy::CreateAtenFromLtcTensor(
|
|
||||||
torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device)
|
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<at::Tensor> LazyNativeFunctions::split_with_sizes_copy_symint(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) {
|
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
|
||||||
auto common_device = torch::lazy::GetBackendDevice(self);
|
|
||||||
TORCH_INTERNAL_ASSERT(common_device);
|
|
||||||
|
|
||||||
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
|
||||||
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitWithSizesCopy>(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim);
|
|
||||||
if (!node) {
|
|
||||||
auto self_meta = to_meta(self);
|
|
||||||
auto out_meta = at::compositeexplicitautogradnonfunctional::split_with_sizes_copy_symint(self_meta, split_sizes, dim);
|
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> shapes;
|
|
||||||
for (const auto & shape : out_meta) {
|
|
||||||
shapes.push_back(
|
|
||||||
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
if(torch::lazy::symbolicShapeEnabled()){
|
|
||||||
std::vector<torch::jit::IValue> inputs = { self, split_sizes, dim };
|
|
||||||
const char* schema_str = "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]";
|
|
||||||
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
|
||||||
}
|
|
||||||
|
|
||||||
node = torch::lazy::MakeNode<SplitWithSizesCopy>(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, std::move(shapes));
|
|
||||||
CacheNode(node);
|
CacheNode(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<at::Tensor> result;
|
std::vector<at::Tensor> result;
|
||||||
for (size_t i = 0; i < node->num_outputs(); ++i) {
|
for (size_t i = 0; i < node->num_outputs(); ++i) {
|
||||||
result.push_back(
|
result.push_back(
|
||||||
torch::lazy::CreateAtenFromLtcTensor(
|
torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
|
||||||
torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device)
|
torch::lazy::Value(node, i), *common_device)));
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<at::Tensor> LazyNativeFunctions::split_copy_symint(const at::Tensor & self, c10::SymInt split_size, int64_t dim) {
|
std::vector<at::Tensor> LazyNativeFunctions::split_with_sizes_copy_symint(
|
||||||
|
const at::Tensor &self, c10::SymIntArrayRef split_sizes, int64_t dim) {
|
||||||
TORCH_LAZY_FN_COUNTER("lazy::");
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
auto common_device = torch::lazy::GetBackendDevice(self);
|
auto common_device = torch::lazy::GetBackendDevice(self);
|
||||||
TORCH_INTERNAL_ASSERT(common_device);
|
TORCH_INTERNAL_ASSERT(common_device);
|
||||||
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
|
||||||
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitCopyTensor>(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim);
|
LazyTensorPtr lazy_self =
|
||||||
|
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
||||||
|
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitWithSizesCopy>(
|
||||||
|
lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim);
|
||||||
if (!node) {
|
if (!node) {
|
||||||
auto self_meta = to_meta(self);
|
auto self_meta = to_meta(self);
|
||||||
auto out_meta = at::compositeexplicitautogradnonfunctional::split_copy_symint(self_meta, split_size, dim);
|
auto out_meta = at::compositeexplicitautogradnonfunctional::
|
||||||
|
split_with_sizes_copy_symint(self_meta, split_sizes, dim);
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> shapes;
|
std::vector<torch::lazy::Shape> shapes;
|
||||||
for (const auto & shape : out_meta) {
|
for (const auto &shape : out_meta) {
|
||||||
shapes.push_back(
|
shapes.push_back(
|
||||||
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())
|
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()));
|
||||||
);
|
}
|
||||||
|
|
||||||
|
if (torch::lazy::symbolicShapeEnabled()) {
|
||||||
|
std::vector<torch::jit::IValue> inputs = {self, split_sizes, dim};
|
||||||
|
const char *schema_str = "aten::split_with_sizes_copy(Tensor self, "
|
||||||
|
"SymInt[] split_sizes, int dim=0) -> Tensor[]";
|
||||||
|
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
node = torch::lazy::MakeNode<SplitWithSizesCopy>(
|
||||||
|
lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim,
|
||||||
|
std::move(shapes));
|
||||||
|
CacheNode(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<at::Tensor> result;
|
||||||
|
for (size_t i = 0; i < node->num_outputs(); ++i) {
|
||||||
|
result.push_back(
|
||||||
|
torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
|
||||||
|
torch::lazy::Value(node, i), *common_device)));
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<at::Tensor>
|
||||||
|
LazyNativeFunctions::split_copy_symint(const at::Tensor &self,
|
||||||
|
c10::SymInt split_size, int64_t dim) {
|
||||||
|
TORCH_LAZY_FN_COUNTER("lazy::");
|
||||||
|
auto common_device = torch::lazy::GetBackendDevice(self);
|
||||||
|
TORCH_INTERNAL_ASSERT(common_device);
|
||||||
|
LazyTensorPtr lazy_self =
|
||||||
|
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
|
||||||
|
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitCopyTensor>(
|
||||||
|
lazy_self->GetIrValue(), GetSymIntValue(split_size), dim);
|
||||||
|
if (!node) {
|
||||||
|
auto self_meta = to_meta(self);
|
||||||
|
auto out_meta =
|
||||||
|
at::compositeexplicitautogradnonfunctional::split_copy_symint(
|
||||||
|
self_meta, split_size, dim);
|
||||||
|
|
||||||
|
std::vector<torch::lazy::Shape> shapes;
|
||||||
|
for (const auto &shape : out_meta) {
|
||||||
|
shapes.push_back(
|
||||||
|
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()));
|
||||||
}
|
}
|
||||||
const size_t num_outputs = shapes.size();
|
const size_t num_outputs = shapes.size();
|
||||||
|
|
||||||
if(torch::lazy::symbolicShapeEnabled()){
|
if (torch::lazy::symbolicShapeEnabled()) {
|
||||||
std::vector<torch::jit::IValue> inputs = { self, split_size, dim };
|
std::vector<torch::jit::IValue> inputs = {self, split_size, dim};
|
||||||
const char* schema_str = "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]";
|
const char *schema_str = "aten::split_copy.Tensor(Tensor self, SymInt "
|
||||||
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
"split_size, int dim=0) -> Tensor[]";
|
||||||
|
applySymbolicShapesOnLT(schema_str, inputs, shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
node = torch::lazy::MakeNode<SplitCopyTensor>(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, std::move(shapes), num_outputs);
|
node = torch::lazy::MakeNode<SplitCopyTensor>(
|
||||||
|
lazy_self->GetIrValue(), GetSymIntValue(split_size), dim,
|
||||||
|
std::move(shapes), num_outputs);
|
||||||
CacheNode(node);
|
CacheNode(node);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<at::Tensor> result;
|
std::vector<at::Tensor> result;
|
||||||
for (size_t i = 0; i < node->num_outputs(); ++i) {
|
for (size_t i = 0; i < node->num_outputs(); ++i) {
|
||||||
result.push_back(
|
result.push_back(
|
||||||
torch::lazy::CreateAtenFromLtcTensor(
|
torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
|
||||||
torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device)
|
torch::lazy::Value(node, i), *common_device)));
|
||||||
)
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor LazyNativeFunctions::index(
    const at::Tensor &self,
    const c10::List<c10::optional<at::Tensor>> &indices) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto common_device = torch::lazy::GetBackendDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);
  LazyTensorPtr lazy_self =
      torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);

  std::vector<torch::lazy::Value> values;
  for (const auto &it : indices) {
    c10::optional<at::Tensor> tensor = it;
    LazyTensorPtr lazy_tensor =
        torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
    values.push_back(
        lazy_tensor
            ? lazy_tensor->GetIrValue()
            : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
  }

  auto list = MakeNode<TorchMlirOptionalTensorList>(values);

  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<IndexTensor>(lazy_self->GetIrValue(), list);

  if (!node) {
    auto self_meta = to_meta(self);
    auto indices_meta = to_meta(indices);
    auto out_meta = at::meta::index(self_meta, indices_meta);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, indices};
      const char *schema_str =
          "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<IndexTensor>(lazy_self->GetIrValue(), list,
                                              std::move(shapes));
    CacheNode(node);
  }

  auto result = torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::LazyTensor::Create(std::move(node), *common_device));

  return result;
}

at::Tensor LazyNativeFunctions::index_put(
    const at::Tensor &self, const c10::List<c10::optional<at::Tensor>> &indices,
    const at::Tensor &values, bool accumulate) {
  TORCH_LAZY_FN_COUNTER("lazy::");
  auto common_device = torch::lazy::GetBackendDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);
  LazyTensorPtr lazy_self =
      torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
  LazyTensorPtr lazy_valeus =
      torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device);

  std::vector<torch::lazy::Value> indices_vector;
  for (const auto &it : indices) {
    c10::optional<at::Tensor> tensor = it;
    LazyTensorPtr lazy_tensor =
        torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
    indices_vector.push_back(
        lazy_tensor
            ? lazy_tensor->GetIrValue()
            : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
  }

  auto indices_list = MakeNode<TorchMlirOptionalTensorList>(indices_vector);

  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<IndexPut>(lazy_self->GetIrValue(), indices_list,
                                       lazy_valeus->GetIrValue(), accumulate);

  if (!node) {
    auto self_meta = to_meta(self);
    auto indices_meta = to_meta(indices);
    auto values_meta = to_meta(values);

    auto out_meta = at::compositeexplicitautograd::index_put(
        self_meta, indices_meta, values_meta, accumulate);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, indices, values};
      const char *schema_str =
          "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool "
          "accumulate=False) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<IndexPut>(
        lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(),
        accumulate, std::move(shapes));
    CacheNode(node);
  }

  auto result = torch::lazy::CreateAtenFromLtcTensor(
      torch::lazy::LazyTensor::Create(std::move(node), *common_device));

  return result;
}

// This is needed by the torch.tensor constructor.
// LazyTensor always opts into functionalization.
// "lifting" a tensor for functionalization means wrapping it in a
// FunctionalTensorWrapper object.
at::Tensor LazyNativeFunctions::lift(const at::Tensor &tensor) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  return at::functionalization::impl::to_functional_tensor(tensor);
}

at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor &tensor) {
  TORCH_INTERNAL_ASSERT(
      !at::functionalization::impl::isFunctionalTensor(tensor));
  return at::functionalization::impl::to_functional_tensor(tensor);
}

// All of the below ops correspond to CompositeExplicitAutograd kernels from
// core that call into view operators internally. These are all composite ops
// that LTC can technically re-use / get for free, but we need to
// "functionalize" them to remove the view ops before we can use them.
at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      block_diag)>::call(tensors);
}
at::Tensor LazyNativeFunctions::new_empty_strided_symint(
    const at::Tensor &self, c10::SymIntArrayRef size,
    c10::SymIntArrayRef stride, c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout, c10::optional<at::Device> device,
    c10::optional<bool> pin_memory) {
  if (!device || device->type() == c10::DeviceType::Lazy) {
    return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
        new_empty_strided)>::call(self, size, stride, dtype, layout, device,
                                  pin_memory);
  }
  // For cases when device != lazy, for example:
  // lazy_tensor.new_empty_strided(..., "cpu") we need to avoid explicit
  // functionalization. To do that we create regular cpu tensors.
  at::Tensor t = at::empty_symint(
      size, (dtype ? dtype : c10::optional<at::ScalarType>(self.scalar_type())),
      (layout ? layout : c10::optional<at::Layout>(self.layout())), device,
@@ -585,65 +634,63 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint(
  return t.as_strided_symint(size, stride, /*storage_offset=*/0);
}

at::Tensor LazyNativeFunctions::narrow_copy_symint(const at::Tensor &self,
                                                   int64_t dim,
                                                   c10::SymInt start,
                                                   c10::SymInt length) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      narrow_copy)>::call(self, dim, start, length);
}
at::Tensor LazyNativeFunctions::pixel_shuffle(const at::Tensor &self,
                                              int64_t upscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_shuffle)>::call(self, upscale_factor);
}
at::Tensor LazyNativeFunctions::pixel_unshuffle(const at::Tensor &self,
                                                int64_t downscale_factor) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      pixel_unshuffle)>::call(self, downscale_factor);
}
at::Tensor LazyNativeFunctions::select_backward(const at::Tensor &grad_output,
                                                at::IntArrayRef input_sizes,
                                                int64_t dim, int64_t index) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      select_backward)>::call(grad_output, input_sizes, dim, index);
}
at::Tensor LazyNativeFunctions::slice_backward_symint(
    const at::Tensor &grad_output, at::SymIntArrayRef input_sizes, int64_t dim,
    c10::SymInt start, c10::SymInt end, c10::SymInt step) {
  return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
      slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}
at::Tensor LazyNativeFunctions::diagonal_backward(const at::Tensor &grad_output,
                                                  at::IntArrayRef input_sizes,
                                                  int64_t offset, int64_t dim1,
                                                  int64_t dim2) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
}
at::Tensor LazyNativeFunctions::_trilinear(
    const at::Tensor &i1, const at::Tensor &i2, const at::Tensor &i3,
    at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3,
    at::IntArrayRef sumdim, int64_t unroll_dim) {
  return at::functionalization::functionalize_aten_op<ATEN_OP(
      _trilinear)>::call(i1, i2, i3, expand1, expand2, expand3, sumdim,
                         unroll_dim);
}
at::Tensor LazyNativeFunctions::linalg_pinv(
    const at::Tensor &self, const c10::optional<at::Tensor> &atol,
    const c10::optional<at::Tensor> &rtol, bool hermitian) {
  return at::functionalization::functionalize_aten_op<ATEN_OP2(
      linalg_pinv, atol_rtol_tensor)>::call(self, atol, rtol, hermitian);
}

// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy any mutations
// back to the inputs.
at::Tensor &LazyNativeFunctions::logsumexp_out(const at::Tensor &self,
                                               at::IntArrayRef dim,
                                               bool keepdim, at::Tensor &out) {
  auto self_wrapped = at::functionalization::impl::to_functional_tensor(self);
  auto out_wrapped = at::functionalization::impl::to_functional_tensor(out);
  // directly call the composite kernel from core.
@ -18,11 +18,10 @@ namespace lazy {

namespace {

hash_t OperandHashes(const OpList &operands, const c10::ArrayRef<Shape> &shapes,
                     const hash_t &seed, bool bakeInSizes) {
  hash_t hash = seed;
  for (auto &operand : operands) {
    if (!operand) {
      hash = HashCombine(hash, static_cast<uint64_t>(kNullOpt));
      continue;

@ -30,7 +29,7 @@ hash_t OperandHashes(
    auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash();
    hash = HashCombine(hash, operand_hash);
  }
  for (auto &shape : shapes) {
    hash = HashCombine(hash, shape.hash(bakeInSizes));
  }
  return hash;

@ -38,53 +37,51 @@ hash_t OperandHashes(

} // namespace

// Adds a static hook that is run after every single TorchMlirNode is
// initialized
static std::vector<std::function<void(TorchMlirNode *)>> constructor_hooks;
void TorchMlirNode::addConstructorHook(std::function<void(TorchMlirNode *)> f) {
  constructor_hooks.emplace_back(f);
}

TorchMlirNode::TorchMlirNode(OpKind op, OpList operands,
                             std::vector<Shape> &&shapes, size_t num_outputs,
                             hash_t hash_seed)
    : Node(op, operands, std::move(shapes), num_outputs) {
  hash_seed = HashCombine(op.hash(), hash_seed);
  shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
  dag_hash_ = (enableDynamicShape()
                   ? OperandHashes(operands, this->shapes(), hash_seed, false)
                   : shape_hash_);

  for (std::function<void(TorchMlirNode *)> &f : constructor_hooks) {
    f(this);
  }
}

TorchMlirNode::TorchMlirNode(OpKind op, OpList operands,
                             const std::function<Shape()> &shape_fn,
                             size_t num_outputs, hash_t hash_seed)
    : TorchMlirNode(op, operands, std::vector<Shape>{}, num_outputs,
                    hash_seed) {
  addComputedShape(shape_fn);
}

TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, size_t num_outputs,
                             hash_t hash_seed)
    : TorchMlirNode(op, operands, std::vector<Shape>{}, num_outputs,
                    hash_seed) {}

TorchMlirNode::TorchMlirNode(OpKind op, Shape shape, size_t num_outputs,
                             hash_t hash_seed)
    : TorchMlirNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}

hash_t TorchMlirNode::hash() const { return dag_hash_; }

hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }

TorchMlirNode *TorchMlirNode::mlir_node(int index) const {
  return dynamic_cast<TorchMlirNode *>(operands_.at(index).get());
}

///////////////////////////////////////////////////////////////////////////////

@ -107,11 +104,12 @@ TorchMlirTensorList::TorchMlirTensorList(OpList values)
          /*num_outputs=*/1,
          /*hash_seed=*/kHashSeed) {}

torch::lazy::TorchMlirOpVector
TorchMlirTensorList::Lower(TorchMlirFunction function,
                           TorchMlirLoweringContext *loctx) const {
  std::vector<torch::jit::Value *> tensor_list;
  CHECK(!operands().empty());
  for (const torch::lazy::Output &operand : operands()) {
    tensor_list.emplace_back(loctx->GetOutputOp(operand));
  }
  auto graph = function->graph();

@ -140,16 +138,17 @@ TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values)
          /*num_outputs=*/1,
          /*hash_seed=*/kHashSeed) {}

torch::lazy::TorchMlirOpVector
TorchMlirOptionalTensorList::Lower(TorchMlirFunction function,
                                   TorchMlirLoweringContext *loctx) const {
  std::vector<torch::jit::Value *> tensor_list;
  CHECK(!operands().empty());
  for (const torch::lazy::Output &operand : operands()) {
    tensor_list.emplace_back(loctx->GetOutputOp(operand));
  }
  auto graph = function->graph();
  auto listnode = graph->insertNode(graph->createList(
      c10::OptionalType::create(c10::TensorType::get()), tensor_list));
  return {listnode->output()};
}
@ -27,23 +27,22 @@ namespace lazy {

class TORCH_API TorchMlirNode : public torch::lazy::Node {
public:
  TorchMlirNode(OpKind op, OpList operands, std::vector<Shape> &&shapes,
                size_t num_outputs, hash_t hash_seed = kHashSeed);

  TorchMlirNode(OpKind op, OpList operands,
                const std::function<Shape()> &shape_fn, size_t num_outputs,
                hash_t hash_seed = kHashSeed);

  TorchMlirNode(OpKind op, OpList operands, size_t num_outputs,
                hash_t hash_seed = kHashSeed);

  TorchMlirNode(OpKind op, Shape shape, size_t num_outputs,
                hash_t hash_seed = kHashSeed);

  // Adds a static hook that is run after every single TorchMlirNode is
  // constructed
  static void addConstructorHook(std::function<void(TorchMlirNode *)>);

  ~TorchMlirNode() override = default;

@ -51,10 +50,10 @@ public:

  hash_t shapeHash() const override;

  TorchMlirNode *mlir_node(int index) const;

  virtual TorchMlirOpVector Lower(TorchMlirFunction function,
                                  TorchMlirLoweringContext *loctx) const;

private:
  // The hash of the dag WITH size info. Used for shape caching

@ -86,22 +85,23 @@ struct TORCH_API TorchMlirTensorList : public TorchMlirNode {
  TorchMlirTensorList() = delete;
  TorchMlirTensorList(OpList values);

  torch::lazy::TorchMlirOpVector
  Lower(TorchMlirFunction function,
        TorchMlirLoweringContext *loctx) const override;
};

// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also
// represent optional tensors, so the output type for this op is
// !torch.list<optional<vtensor>>.
struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode {
  static OpKind ClassOpKind();

  TorchMlirOptionalTensorList() = delete;
  TorchMlirOptionalTensorList(OpList values);

  torch::lazy::TorchMlirOpVector
  Lower(TorchMlirFunction function,
        TorchMlirLoweringContext *loctx) const override;
};

} // namespace lazy
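Note: the constructor hook declared above is the only extension point the node class itself exposes; a hypothetical use (not part of this commit) is simply registering a lambda, for example to count how many IR nodes get built:

```cpp
// Hypothetical usage, not from this diff: run a callback after every
// TorchMlirNode is constructed.
static size_t constructed_nodes = 0;

void registerNodeCounter() {
  torch::lazy::TorchMlirNode::addConstructorHook(
      [](torch::lazy::TorchMlirNode *node) {
        if (node) // the hook receives the freshly constructed node
          ++constructed_nodes;
      });
}
```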
@ -31,21 +31,23 @@
namespace torch {
namespace lazy {

TorchMlirOpVector
LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym,
                      const std::vector<c10::TypePtr> tensor_types,
                      const std::vector<torch::jit::NamedValue> &arguments,
                      const std::vector<torch::jit::NamedValue> &kwarguments) {
  // Workaround for ListType::isSubtypeOfExt behavior which leads to
  // the problems with JIT schema matching, so we need to keep
  // c10::ListType empty before magic_method->call function call.
  auto dummy_graph = torch::jit::Graph();
  for (auto arg : arguments) {
    torch::jit::Value *value = arg.value(dummy_graph);
    if (value->type()->kind() == c10::TypeKind::ListType) {
      auto list_element_type =
          value->type()->cast<c10::ListType>()->getElementType();
      if (list_element_type->cast<c10::OptionalType>()) {
        value->setType(c10::ListType::create(
            c10::OptionalType::create(c10::TensorType::get())));
      } else {
        value->setType(c10::ListType::create(c10::TensorType::get()));
      }

@ -56,25 +58,27 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
      std::make_shared<torch::jit::BuiltinFunction>(sym, at::nullopt);
  auto magic_method = std::make_shared<torch::jit::MagicMethod>("", builtin);
  auto ret = magic_method->call({}, *function, arguments, kwarguments, 0);
  auto sv = dynamic_cast<torch::jit::SimpleValue *>(ret.get());
  CHECK(sv);

  TorchMlirOpVector results;
  if (sv->getValue()->type()->kind() == c10::TypeKind::ListType) {
    // Unpack dynamic multi-output operations like aten::split with Tensor[]
    // output type. This is required to have consistent input types for
    // multi-output node consumers.
    torch::jit::Node *node = function->graph()->createListUnpack(
        sv->getValue(), tensor_types.size());
    function->graph()->insertNode(node);
    for (const auto &output : node->outputs()) {
      results.push_back(output);
    }
  } else if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
    // Op returns multiple values and the number of outputs is static and
    // defined by the operation schema.
    const auto tuple_call_result = sv->asTuple({}, *function);
    for (const auto &tuple_component : tuple_call_result) {
      auto tuple_component_sv =
          dynamic_cast<torch::jit::SimpleValue *>(tuple_component.get());
      results.push_back(tuple_component_sv->getValue());
    }
  } else {

@ -84,7 +88,7 @@ TorchMlirOpVector LowerTorchMlirBuiltin(

  // Insert known tensor type information.
  unsigned tensor_type_idx = 0;
  for (jit::Value *value : results) {
    if (value->type()->kind() == c10::TypeKind::TensorType) {
      TORCH_CHECK(
          tensor_type_idx < tensor_types.size(), function->graph()->toString(),

@ -97,23 +101,22 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
  }

  // Ensure that we use up all the known tensor type information available.
  TORCH_CHECK(tensor_type_idx == tensor_types.size(), tensor_type_idx,
              " known types were injected into jit::Value, but ",
              tensor_types.size(), " were provided from lazy::Node!");

  return results;
}

TorchMlirOpVector
LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym,
                      const c10::ArrayRef<Shape> result_shapes,
                      const std::vector<torch::jit::NamedValue> &arguments,
                      const std::vector<torch::jit::NamedValue> &kwarguments) {
  std::vector<c10::TypePtr> tensor_types;

  // Generate types with fixed tensor shape information.
  for (const Shape &shape : result_shapes) {
    tensor_types.push_back(torch::jit::TensorType::create(
        /*scalar_type=*/shape.scalar_type(),
        /*device=*/c10::nullopt,

@ -122,34 +125,34 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
        /*requires_grad=*/c10::nullopt));
  }

  return LowerTorchMlirBuiltin(function, sym, tensor_types, arguments,
                               kwarguments);
}

TorchMlirOpVector
LowerBuiltin(const torch::lazy::Node *node, TorchMlirFunction function,
             const std::vector<torch::jit::NamedValue> &arguments,
             const std::vector<torch::jit::NamedValue> &kwarguments = {}) {
  return LowerTorchMlirBuiltin(function, node->op().op, node->shapes(),
                               arguments, kwarguments);
}
TorchMlirOpVector
LowerBuiltin(c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
             TorchMlirFunction function,
             const std::vector<torch::jit::NamedValue> &arguments,
             const std::vector<torch::jit::NamedValue> &kwarguments = {}) {
  return LowerTorchMlirBuiltin(function, sym, result_shapes, arguments,
                               kwarguments);
}
TorchMlirOpVector
LowerBuiltin(c10::Symbol sym, const std::vector<c10::TypePtr> types,
             TorchMlirFunction function,
             const std::vector<torch::jit::NamedValue> &arguments,
             const std::vector<torch::jit::NamedValue> &kwarguments = {}) {
  return LowerTorchMlirBuiltin(function, sym, types, arguments, kwarguments);
}

c10::TensorType &cast_tensor_type(c10::TypePtr value_type) {
  auto tensor_type = value_type->cast<c10::TensorType>();
  TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!");

@ -157,8 +160,8 @@ c10::TensorType& cast_tensor_type(c10::TypePtr value_type) {
}

c10::optional<std::vector<int64_t>>
get_tensor_type_shape(c10::TensorType &tensor_type) {
  auto &symbolic_shape = tensor_type.symbolic_sizes();
  if (!symbolic_shape.rank()) {
    return c10::nullopt;
  }

@ -175,21 +178,21 @@ get_tensor_type_shape(c10::TensorType& tensor_type) {
}

std::vector<torch::lazy::Shape> compute_shape_copy(c10::TypePtr value_type) {
  c10::TensorType &tensor_type = cast_tensor_type(value_type);

  auto maybe_dims = get_tensor_type_shape(tensor_type);
  TORCH_CHECK(maybe_dims.has_value(), "Cannot copy unranked tensor!");

  auto scalar_type = tensor_type.scalarType();
  TORCH_CHECK(scalar_type.has_value(),
              "Unable to copy due to lack of scalar type!");
  return {Shape(scalar_type.value(), maybe_dims.value())};
}

std::vector<torch::lazy::Shape> compute_shape_slice(c10::TypePtr value_type,
                                                    int64_t dim, int64_t start,
                                                    int64_t end, int64_t step) {
  c10::TensorType &tensor_type = cast_tensor_type(value_type);

  auto maybe_dims = get_tensor_type_shape(tensor_type);
  TORCH_CHECK(maybe_dims.has_value(), "Cannot slice unranked tensor!");

@ -217,13 +220,13 @@ std::vector<torch::lazy::Shape> compute_shape_slice(
  }

  auto scalar_type = tensor_type.scalarType();
  TORCH_CHECK(scalar_type.has_value(),
              "Unable to slice due to lack of scalar type!");
  return {Shape(scalar_type.value(), dims)};
}

torch::jit::Value *GenerateClone(torch::jit::Value *val,
                                 TorchMlirFunction function) {
  std::vector<torch::jit::NamedValue> clone_arguments;
  clone_arguments.emplace_back(val);

@ -234,20 +237,19 @@ GenerateClone(torch::jit::Value* val, TorchMlirFunction function) {
  return cloned.front();
}

void GenerateCopy(torch::jit::Value *destination, torch::jit::Value *source,
                  TorchMlirFunction function) {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(destination);
  arguments.emplace_back(source);
  LowerBuiltin(at::aten::copy_,
               c10::ArrayRef<Shape>(compute_shape_copy(source->type())),
               function, arguments);
}

torch::jit::Value *GenerateSlice(torch::jit::Value *base, int64_t dim,
                                 int64_t start, int64_t end, int64_t step,
                                 TorchMlirFunction function) {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(base);
  arguments.emplace_back(dim);

@ -255,11 +257,11 @@ torch::jit::Value* GenerateSlice(
  arguments.emplace_back(end);
  arguments.emplace_back(step);

  TorchMlirOpVector selected =
      LowerBuiltin(at::aten::slice,
                   c10::ArrayRef<Shape>(compute_shape_slice(base->type(), dim,
                                                            start, end, step)),
                   function, arguments);
  TORCH_CHECK_EQ(selected.size(), 1);
  return selected.front();
}

@ -267,10 +269,10 @@ torch::jit::Value* GenerateSlice(
// Node Lowerings

// Default Node Lowering
TorchMlirOpVector TorchMlirNode::Lower(TorchMlirFunction function,
                                       TorchMlirLoweringContext *loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  for (const torch::lazy::Output &output : operands()) {
    arguments.emplace_back(loctx->GetOutputOp(output));
  }
  return LowerBuiltin(this, function, arguments);

@ -280,19 +282,19 @@ TorchMlirOpVector TorchMlirNode::Lower(

// Non-native nodes

TorchMlirOpVector Cast::Lower(TorchMlirFunction function,
                              TorchMlirLoweringContext *loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  arguments.emplace_back(dtype);
  return LowerBuiltin(at::aten::to, shapes(), function, arguments);
}

TorchMlirOpVector DeviceData::Lower(TorchMlirFunction function,
                                    TorchMlirLoweringContext *loctx) const {
  auto infoptr = data_->info();
  auto deviceDataInfoPtr =
      (torch::lazy::LazyGraphExecutor::DeviceDataInfo *)infoptr;
  if (GRAPH_DUMP_ENABLED) {
    LOG(ERROR) << "Lowering device data node, tensor id "
               << deviceDataInfoPtr->tensor_id << std::endl;

@ -300,8 +302,8 @@ TorchMlirOpVector DeviceData::Lower(
  return {loctx->GetParameter(data_)};
}

TorchMlirOpVector Scalar::Lower(TorchMlirFunction function,
                                TorchMlirLoweringContext *loctx) const {
  auto options =
      at::TensorOptions()
          .device(torch::lazy::getBackend()->EagerFallbackDeviceType())

@ -309,8 +311,8 @@ TorchMlirOpVector Scalar::Lower(
  return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
}

TorchMlirOpVector Expand::Lower(TorchMlirFunction function,
                                TorchMlirLoweringContext *loctx) const {
  std::vector<torch::jit::NamedValue> arguments;
  arguments.emplace_back(loctx->GetOutputOp(operand(0)));
  arguments.emplace_back(size);
@ -18,14 +18,14 @@
namespace torch {
namespace lazy {

typedef std::vector<torch::jit::Value *> TorchMlirOpVector;
typedef std::shared_ptr<torch::jit::GraphFunction> TorchMlirFunction;

TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin(
    TorchMlirFunction function, c10::Symbol sym,
    const c10::ArrayRef<Shape> result_shapes,
    const std::vector<torch::jit::NamedValue> &arguments,
    const std::vector<torch::jit::NamedValue> &kwarguments = {});

} // namespace lazy
} // namespace torch
@ -2,18 +2,16 @@

#include <torch/csrc/lazy/core/ir_builder.h>

#include "../backend_impl.h"
#include "device_data.h"

namespace torch {
namespace lazy {

DeviceData::DeviceData(std::shared_ptr<BackendData> data)
    : TorchMlirNode(ClassOpKind(), data->shape(),
                    /*num_outputs=*/1,
                    /*hash_seed=*/static_cast<uint32_t>(101)),
      data_(std::move(data)) {
  propagate_name();
}

@ -21,9 +19,11 @@ DeviceData::DeviceData(std::shared_ptr<BackendData> data)
void DeviceData::propagate_name() {
  if (data_ && name_ != "") {
    // Add device data name to backend data
    TorchMlirBackendData *mlir_data =
        dynamic_cast<TorchMlirBackendData *>(data_.get());
    TORCH_CHECK(mlir_data);
    auto *info =
        dynamic_cast<TorchMlirBackendData::Info *>(mlir_data->mlir_info());
    TORCH_CHECK(info);
    info->name = name_;
  }

@ -34,7 +34,7 @@ void DeviceData::SetData(std::shared_ptr<BackendData> data) {
  propagate_name();
}

void DeviceData::SetName(const std::string &name) {
  name_ = name;
  propagate_name();
}

@ -43,12 +43,12 @@ std::string DeviceData::ToString() const {
  std::stringstream ss;
  ss << TorchMlirNode::ToString() << ", device=" << data_->device();
  if (name_ != "") {
    ss << ", name=" << name_;
  }
  return ss.str();
}

const DeviceData *DeviceData::Cast(const Node *node) {
  return NodeCast<DeviceData>(node);
}

@ -59,7 +59,7 @@ NodePtr DeviceData::Create(std::shared_ptr<BackendData> data) {
  // Ditching the old data_ is safe because tracing is done iteration
  // by iteration, and after we launch the async device execution for the
  // previous iteration, data_ in DeviceData nodes are not needed anymore.
  DeviceData *device_data = static_cast<DeviceData *>(node.get());
  device_data->SetData(data);
  return node;
}
@ -6,15 +6,12 @@
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>

namespace torch {
namespace lazy {

class TORCH_API DeviceData : public TorchMlirNode {
public:
  static OpKind ClassOpKind() { return ltc_device_data; }

  explicit DeviceData(std::shared_ptr<BackendData> data);

@ -27,22 +24,23 @@ class TORCH_API DeviceData : public TorchMlirNode {

  std::string ToString() const override;

  const std::shared_ptr<BackendData> &data() const { return data_; }

  void SetData(std::shared_ptr<BackendData> data);

  TorchMlirOpVector Lower(TorchMlirFunction function,
                          TorchMlirLoweringContext *loctx) const override;

  static const DeviceData *Cast(const Node *node);

  // To reuse IR nodes, use this method to create DeviceData nodes
  // instead of calling the constructor directly.
  static NodePtr Create(std::shared_ptr<BackendData> data);

  const std::string &GetName() const { return name_; }
  void SetName(const std::string &name);

private:
  void propagate_name();

  std::shared_ptr<BackendData> data_;
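Note: `Create` (rather than the constructor) is the intended entry point so existing nodes can be reused; a hypothetical caller, not part of this commit and with a made-up helper name, might look like this:

```cpp
// Hypothetical usage, not from this diff: build (or reuse) a DeviceData node
// for a backend buffer and attach a debug name to it via the API above.
torch::lazy::NodePtr MakeNamedDeviceData(
    std::shared_ptr<torch::lazy::BackendData> data, const std::string &name) {
  torch::lazy::NodePtr node = torch::lazy::DeviceData::Create(std::move(data));
  // Create() always returns a DeviceData node, so this downcast is safe.
  auto *device_data = static_cast<torch::lazy::DeviceData *>(node.get());
  device_data->SetName(name);
  return node;
}
```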
@ -15,12 +15,8 @@
namespace torch {
namespace lazy {

Generic::Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs,
                 hash_t hash_seed)
    : TorchMlirNode(op, operands, {std::move(shape)}, num_outputs, hash_seed),
      hash_seed_(hash_seed) {}
@ -23,15 +23,11 @@ namespace lazy {
// captured by the LowerFn), but they should instead create a dedicated IR node.
// Doing the former would limit IR introspection.
class TORCH_API Generic : public TorchMlirNode {
public:
  Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs = 1,
          hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9));

private:
  hash_t hash_seed_;
};
@ -12,9 +12,9 @@
namespace torch {
namespace lazy {

IndexTensor::IndexTensor(const torch::lazy::Value &self,
                         const torch::lazy::Value &indices,
                         std::vector<torch::lazy::Shape> &&shapes)
    : torch::lazy::TorchMlirNode(IndexTensor::ClassOpKind(),
                                 OpList{self, indices}, std::move(shapes),
                                 /* num_outputs */ 1, torch::lazy::MHash()) {}

@ -25,13 +25,13 @@ std::string IndexTensor::ToString() const {
  return ss.str();
}

bool IndexTensor::CanBeReused(const torch::lazy::Value &self,
                              const torch::lazy::Value &indices) const {
  return false;
}

TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function,
                                     TorchMlirLoweringContext *loctx) const {
  PRINT_FUNCTION();
  std::vector<torch::jit::NamedValue> arguments;
  std::vector<torch::jit::NamedValue> kwarguments;

@ -49,10 +49,10 @@ TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function,
  return index_out;
}

IndexPut::IndexPut(const torch::lazy::Value &self,
                   const torch::lazy::Value &indices,
                   const torch::lazy::Value &values, bool accumulate,
                   std::vector<torch::lazy::Shape> &&shapes)
    : torch::lazy::TorchMlirNode(
          IndexPut::ClassOpKind(), OpList{self, indices, values},
          std::move(shapes),

@ -66,15 +66,15 @@ std::string IndexPut::ToString() const {
  return ss.str();
}

bool IndexPut::CanBeReused(const torch::lazy::Value &self,
                           const torch::lazy::Value &indices,
                           const torch::lazy::Value &values,
                           bool accumulate) const {
  return false;
}

TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function,
                                  TorchMlirLoweringContext *loctx) const {
  PRINT_FUNCTION();
  std::vector<torch::jit::NamedValue> arguments;
  std::vector<torch::jit::NamedValue> kwarguments;

@ -95,5 +95,5 @@ TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function,
  return index_out;
}

} // namespace lazy
} // namespace torch
@ -15,44 +15,44 @@ namespace torch {
namespace lazy {

class IndexTensor : public torch::lazy::TorchMlirNode {
public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::index);
  }

  IndexTensor(const torch::lazy::Value &self, const torch::lazy::Value &indices,
              std::vector<torch::lazy::Shape> &&shapes);

  std::string ToString() const override;

  bool CanBeReused(const torch::lazy::Value &self,
                   const torch::lazy::Value &indices) const;

  TorchMlirOpVector Lower(TorchMlirFunction function,
                          TorchMlirLoweringContext *loctx) const override;
};

class IndexPut : public torch::lazy::TorchMlirNode {
public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::index_put);
  }

  IndexPut(const torch::lazy::Value &self, const torch::lazy::Value &indices,
           const torch::lazy::Value &values, bool accumulate,
           std::vector<torch::lazy::Shape> &&shapes);

  std::string ToString() const override;

  bool CanBeReused(const torch::lazy::Value &self,
                   const torch::lazy::Value &indices,
                   const torch::lazy::Value &values, bool accumulate) const;

  TorchMlirOpVector Lower(TorchMlirFunction function,
                          TorchMlirLoweringContext *loctx) const override;

  bool accumulate;
};

} // namespace lazy
} // namespace torch
@ -15,7 +15,7 @@
namespace torch {
namespace lazy {

IValueConstant::IValueConstant(const c10::IValue &value)
    : torch::lazy::TorchMlirNode(IValueConstant::ClassOpKind(), OpList{},
                                 std::vector<Shape>{},
                                 /* num_outputs */ 1, torch::lazy::MHash()),

@ -28,9 +28,9 @@ std::string IValueConstant::ToString() const {
}

TorchMlirOpVector IValueConstant::Lower(TorchMlirFunction function,
                                        TorchMlirLoweringContext *loctx) const {
  return {loctx->graph()->insertConstant(value)};
}

} // namespace lazy
} // namespace torch
@ -18,20 +18,20 @@ namespace lazy {
// parameter which is helpful in different usecases when we need custom
// native ops lowering to torch-mlir IR nodes.
class IValueConstant : public torch::lazy::TorchMlirNode {
public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::prim::Constant);
  }

  IValueConstant(const c10::IValue &value);

  std::string ToString() const override;

  TorchMlirOpVector Lower(TorchMlirFunction function,
                          TorchMlirLoweringContext *loctx) const override;

  c10::IValue value;
};

} // namespace lazy
} // namespace torch
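A hypothetical way to instantiate this node (not shown in this commit) is through the lazy IR builder; `MakeNode` here is an assumption about the upstream `torch::lazy` API rather than something introduced by this diff.

```cpp
// Hypothetical usage, not from this diff: wrap an arbitrary IValue in a
// constant IR node so a custom lowering can consume it.
torch::lazy::NodePtr MakeConstantNode(const c10::IValue &value) {
  // torch::lazy::MakeNode is assumed from <torch/csrc/lazy/core/ir_builder.h>.
  return torch::lazy::MakeNode<torch::lazy::IValueConstant>(value);
}
```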
@ -13,10 +13,10 @@ namespace torch {
namespace lazy {

SplitWithSizesCopy::SplitWithSizesCopy(
    const torch::lazy::Value &self, const ::std::vector<int64_t> &split_sizes,
    const int64_t &dim, std::vector<torch::lazy::Shape> &&shapes)
    : torch::lazy::TorchMlirNode(SplitWithSizesCopy::ClassOpKind(),
                                 OpList{self}, std::move(shapes),
                                 split_sizes.size() /* num_outputs */,
                                 torch::lazy::MHash(split_sizes, dim)),
      split_sizes(split_sizes), dim(dim) {}

@ -29,15 +29,15 @@ std::string SplitWithSizesCopy::ToString() const {
  return ss.str();
}

bool SplitWithSizesCopy::CanBeReused(const torch::lazy::Value &self,
                                     const ::std::vector<int64_t> &split_sizes,
                                     const int64_t &dim) const {
  return false;
}

TorchMlirOpVector
SplitWithSizesCopy::Lower(TorchMlirFunction function,
                          TorchMlirLoweringContext *loctx) const {
  PRINT_FUNCTION();
  std::vector<torch::jit::NamedValue> arguments;
  std::vector<torch::jit::NamedValue> kwarguments;

@ -55,13 +55,13 @@ SplitWithSizesCopy::Lower(TorchMlirFunction function,
  return split_with_sizes_copy_out;
}

SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value &self,
                                 const torch::lazy::Value &split_size,
                                 const int64_t &dim,
                                 std::vector<torch::lazy::Shape> &&shapes,
                                 const size_t num_outputs)
    : torch::lazy::TorchMlirNode(SplitCopyTensor::ClassOpKind(),
                                 OpList{self, split_size}, std::move(shapes),
                                 num_outputs, torch::lazy::MHash(dim)),
      dim(dim) {}

@ -72,15 +72,15 @@ std::string SplitCopyTensor::ToString() const {
  return ss.str();
}

bool SplitCopyTensor::CanBeReused(const torch::lazy::Value &self,
                                  const torch::lazy::Value &split_size,
                                  const int64_t &dim) const {
  return false;
}

TorchMlirOpVector
SplitCopyTensor::Lower(TorchMlirFunction function,
                       TorchMlirLoweringContext *loctx) const {
  PRINT_FUNCTION();
  std::vector<torch::jit::NamedValue> arguments;
  std::vector<torch::jit::NamedValue> kwarguments;
@ -20,19 +20,19 @@ public:
    return torch::lazy::OpKind(at::aten::split_with_sizes_copy);
  }

  SplitWithSizesCopy(const torch::lazy::Value &self,
                     const ::std::vector<int64_t> &split_sizes,
                     const int64_t &dim,
                     std::vector<torch::lazy::Shape> &&shapes);

  std::string ToString() const override;

  bool CanBeReused(const torch::lazy::Value &self,
                   const ::std::vector<int64_t> &split_sizes,
                   const int64_t &dim) const;

  TorchMlirOpVector Lower(TorchMlirFunction function,
                          TorchMlirLoweringContext *loctx) const override;

  std::vector<int64_t> split_sizes;
  int64_t dim;

@ -44,19 +44,19 @@ public:
    return torch::lazy::OpKind(at::aten::split_copy);
  }

  SplitCopyTensor(const torch::lazy::Value &self,
                  const torch::lazy::Value &split_size, const int64_t &dim,
                  std::vector<torch::lazy::Shape> &&shapes,
                  const size_t num_outputs = 1);

  std::string ToString() const override;

  bool CanBeReused(const torch::lazy::Value &self,
                   const torch::lazy::Value &split_size,
                   const int64_t &dim) const;

  TorchMlirOpVector Lower(TorchMlirFunction function,
                          TorchMlirLoweringContext *loctx) const override;

  int64_t dim;
};
@ -17,61 +17,65 @@
namespace torch {
namespace lazy {

// This IR was copied from code-generated output, but the entire _to_copy
// operator cannot be trivially code generated since it is only desirable to
// capture IR for certain permutations of _to_copy (e.g. dtype), and for the
// others it is difficult to even invoke the aten/eager fallback necessitating
// directly implementing the right to(device) behavior
class ToCopy : public torch::lazy::TorchMlirNode {
public:
  ToCopy(const torch::lazy::Value &self,
         const c10::optional<at::ScalarType> &dtype,
         const c10::optional<at::Layout> &layout,
         const c10::optional<at::Device> &device,
         const c10::optional<bool> &pin_memory, const bool &non_blocking,
         const c10::optional<at::MemoryFormat> &memory_format,
         std::vector<torch::lazy::Shape> &&shapes)
      : torch::lazy::TorchMlirNode(
            torch::lazy::OpKind(at::aten::_to_copy), {self}, std::move(shapes),
            /* num_outputs */ 1,
            torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking,
                               memory_format)),

        dtype(dtype), layout(layout), device(device), pin_memory(pin_memory),
        non_blocking(non_blocking), memory_format(memory_format) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << torch::lazy::TorchMlirNode::ToString();
    if (dtype.has_value()) {
      ss << ", dtype=" << dtype.value();
    } else {
      ss << ", dtype=null";
    }
    if (layout.has_value()) {
      ss << ", layout=" << layout.value();
    } else {
      ss << ", layout=null";
    }
    if (device.has_value()) {
      ss << ", device=" << device.value();
    } else {
      ss << ", device=null";
    }
    if (pin_memory.has_value()) {
      ss << ", pin_memory=" << pin_memory.value();
    } else {
      ss << ", pin_memory=null";
    }
    ss << ", non_blocking=" << non_blocking;
    if (memory_format.has_value()) {
      ss << ", memory_format=" << memory_format.value();
    } else {
      ss << ", memory_format=null";
    }
    return ss.str();
  }

  torch::lazy::TorchMlirOpVector
  Lower(TorchMlirFunction function,
        torch::lazy::TorchMlirLoweringContext *loctx) const override {
    std::vector<torch::jit::NamedValue> arguments;
    std::vector<torch::jit::NamedValue> kwarguments;
    arguments.reserve(1);
    kwarguments.reserve(6);
|
||||||
|
@ -83,11 +87,12 @@ class ToCopy : public torch::lazy::TorchMlirNode {
|
||||||
kwarguments.emplace_back("pin_memory", pin_memory);
|
kwarguments.emplace_back("pin_memory", pin_memory);
|
||||||
kwarguments.emplace_back("non_blocking", non_blocking);
|
kwarguments.emplace_back("non_blocking", non_blocking);
|
||||||
kwarguments.emplace_back("memory_format", memory_format);
|
kwarguments.emplace_back("memory_format", memory_format);
|
||||||
torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
|
torch::lazy::TorchMlirOpVector _to_copy_out =
|
||||||
|
torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(),
|
||||||
|
arguments, kwarguments);
|
||||||
TORCH_CHECK_EQ(_to_copy_out.size(), 1);
|
TORCH_CHECK_EQ(_to_copy_out.size(), 1);
|
||||||
|
|
||||||
return _to_copy_out;
|
return _to_copy_out;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c10::optional<at::ScalarType> dtype;
|
c10::optional<at::ScalarType> dtype;
|
||||||
|
@ -97,5 +102,5 @@ class ToCopy : public torch::lazy::TorchMlirNode {
|
||||||
bool non_blocking;
|
bool non_blocking;
|
||||||
c10::optional<at::MemoryFormat> memory_format;
|
c10::optional<at::MemoryFormat> memory_format;
|
||||||
};
|
};
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim,
|
UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value &self, const int64_t &dim,
|
||||||
std::vector<torch::lazy::Shape>&& shapes)
|
std::vector<torch::lazy::Shape> &&shapes)
|
||||||
: torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{ self },
|
: torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{self},
|
||||||
std::move(shapes),
|
std::move(shapes),
|
||||||
self.shape().size(dim), /* num_outputs */
|
self.shape().size(dim), /* num_outputs */
|
||||||
torch::lazy::MHash(dim)),
|
torch::lazy::MHash(dim)),
|
||||||
|
@ -27,13 +27,13 @@ std::string UnbindCopyInt::ToString() const {
|
||||||
return ss.str();
|
return ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool UnbindCopyInt::CanBeReused(const torch::lazy::Value& self,
|
bool UnbindCopyInt::CanBeReused(const torch::lazy::Value &self,
|
||||||
const int64_t& dim) const {
|
const int64_t &dim) const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
TorchMlirOpVector UnbindCopyInt::Lower(TorchMlirFunction function,
|
TorchMlirOpVector UnbindCopyInt::Lower(TorchMlirFunction function,
|
||||||
TorchMlirLoweringContext* loctx) const {
|
TorchMlirLoweringContext *loctx) const {
|
||||||
PRINT_FUNCTION();
|
PRINT_FUNCTION();
|
||||||
std::vector<torch::jit::NamedValue> arguments;
|
std::vector<torch::jit::NamedValue> arguments;
|
||||||
std::vector<torch::jit::NamedValue> kwarguments;
|
std::vector<torch::jit::NamedValue> kwarguments;
|
||||||
|
|
|
@ -20,15 +20,15 @@ public:
|
||||||
return torch::lazy::OpKind(at::aten::unbind_copy);
|
return torch::lazy::OpKind(at::aten::unbind_copy);
|
||||||
}
|
}
|
||||||
|
|
||||||
UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim,
|
UnbindCopyInt(const torch::lazy::Value &self, const int64_t &dim,
|
||||||
std::vector<torch::lazy::Shape>&& shapes);
|
std::vector<torch::lazy::Shape> &&shapes);
|
||||||
|
|
||||||
std::string ToString() const override;
|
std::string ToString() const override;
|
||||||
|
|
||||||
bool CanBeReused(const torch::lazy::Value& self, const int64_t& dim) const;
|
bool CanBeReused(const torch::lazy::Value &self, const int64_t &dim) const;
|
||||||
|
|
||||||
TorchMlirOpVector Lower(TorchMlirFunction function,
|
TorchMlirOpVector Lower(TorchMlirFunction function,
|
||||||
TorchMlirLoweringContext* loctx) const override;
|
TorchMlirLoweringContext *loctx) const override;
|
||||||
|
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
};
|
};
|
||||||
|
|
|
@ -21,21 +21,20 @@ namespace lazy {
|
||||||
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the
|
// TODO(henrytu): Upstream these shape inference functions to PyTorch in the
|
||||||
// future.
|
// future.
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_add(const at::Tensor& self,
|
std::vector<torch::lazy::Shape> compute_shape_add(const at::Tensor &self,
|
||||||
const at::Scalar& other,
|
const at::Scalar &other,
|
||||||
const at::Scalar& alpha) {
|
const at::Scalar &alpha) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<torch::lazy::Shape> compute_shape_sub(const at::Tensor &self,
|
||||||
std::vector<torch::lazy::Shape> compute_shape_sub(const at::Tensor& self,
|
const at::Scalar &other,
|
||||||
const at::Scalar& other,
|
const at::Scalar &alpha) {
|
||||||
const at::Scalar& alpha) {
|
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor& self,
|
std::vector<torch::lazy::Shape> compute_shape_div(const at::Tensor &self,
|
||||||
const at::Scalar& other) {
|
const at::Scalar &other) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,7 +84,7 @@ compute_shape_quantize_per_tensor(const at::Tensor &self, double scale,
|
||||||
return {Shape(dtype, self.sizes().vec())};
|
return {Shape(dtype, self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor& self) {
|
std::vector<torch::lazy::Shape> compute_shape_isinf(const at::Tensor &self) {
|
||||||
return {Shape(at::kBool, self.sizes().vec())};
|
return {Shape(at::kBool, self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,9 +95,8 @@ std::vector<torch::lazy::Shape> compute_shape_quantize_per_channel(
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
|
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
|
||||||
const at::Tensor& self, at::IntArrayRef kernel_size,
|
const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride,
|
||||||
at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
|
at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) {
|
||||||
bool ceil_mode) {
|
|
||||||
auto in_sizes = self.sizes().vec();
|
auto in_sizes = self.sizes().vec();
|
||||||
std::vector<int64_t> dhw(3, 0);
|
std::vector<int64_t> dhw(3, 0);
|
||||||
std::vector<int64_t> paddings = padding.vec();
|
std::vector<int64_t> paddings = padding.vec();
|
||||||
|
@ -106,18 +104,19 @@ std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
|
||||||
std::vector<int64_t> dilations = dilation.vec();
|
std::vector<int64_t> dilations = dilation.vec();
|
||||||
std::vector<int64_t> strides = stride.vec();
|
std::vector<int64_t> strides = stride.vec();
|
||||||
TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ",
|
TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ",
|
||||||
in_sizes);
|
in_sizes);
|
||||||
TORCH_CHECK(kernel_size.size() == 3 &&
|
TORCH_CHECK(kernel_size.size() == 3 && stride.size() == 3 &&
|
||||||
stride.size() == 3 &&
|
padding.size() == 3 && dilation.size() == 3,
|
||||||
padding.size() == 3 &&
|
"max_pool3d requires 3D operands, but got ", kernel_size, stride,
|
||||||
dilation.size() == 3, "max_pool3d requires 3D operands, but got ",
|
padding, dilation);
|
||||||
kernel_size, stride, padding, dilation);
|
|
||||||
int64_t batch = in_sizes[0];
|
int64_t batch = in_sizes[0];
|
||||||
int64_t channel = in_sizes[1]; // NCDHW
|
int64_t channel = in_sizes[1]; // NCDHW
|
||||||
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html
|
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html
|
||||||
for (auto i = 0UL; i<3; ++i) {
|
for (auto i = 0UL; i < 3; ++i) {
|
||||||
double out_size = (in_sizes[2+i] + 2 * paddings[i] - dilations[i] *
|
double out_size = (in_sizes[2 + i] + 2 * paddings[i] -
|
||||||
(ksizes[i] - 1) - 1) / (double)strides[i] + 1;
|
dilations[i] * (ksizes[i] - 1) - 1) /
|
||||||
|
(double)strides[i] +
|
||||||
|
1;
|
||||||
if (ceil_mode)
|
if (ceil_mode)
|
||||||
dhw[i] = (int64_t)std::ceil(out_size);
|
dhw[i] = (int64_t)std::ceil(out_size);
|
||||||
else
|
else
|
||||||
|
@ -129,52 +128,54 @@ std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices_backward(
|
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices_backward(
|
||||||
const at::Tensor & grad_output, const at::Tensor & self,
|
const at::Tensor &grad_output, const at::Tensor &self,
|
||||||
at::IntArrayRef kernel_size, at::IntArrayRef stride,
|
at::IntArrayRef kernel_size, at::IntArrayRef stride,
|
||||||
at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode,
|
at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode,
|
||||||
const at::Tensor & indices) {
|
const at::Tensor &indices) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_mse_loss_backward(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& grad_output, const at::Tensor& self,
|
compute_shape_mse_loss_backward(const at::Tensor &grad_output,
|
||||||
const at::Tensor& target, int64_t reduction) {
|
const at::Tensor &self,
|
||||||
|
const at::Tensor &target, int64_t reduction) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_mul(const at::Tensor& self,
|
std::vector<torch::lazy::Shape> compute_shape_mul(const at::Tensor &self,
|
||||||
const at::Scalar& other) {
|
const at::Scalar &other) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_var(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, at::OptionalIntArrayRef dim,
|
compute_shape_var(const at::Tensor &self, at::OptionalIntArrayRef dim,
|
||||||
const c10::optional<at::Scalar> & correction, bool keepdim) {
|
const c10::optional<at::Scalar> &correction, bool keepdim) {
|
||||||
// Result of variance is scalar tensor.
|
// Result of variance is scalar tensor.
|
||||||
return {Shape(self.scalar_type(), {})};
|
return {Shape(self.scalar_type(), {})};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_nan_to_num(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor & self, c10::optional<double> nan,
|
compute_shape_nan_to_num(const at::Tensor &self, c10::optional<double> nan,
|
||||||
c10::optional<double> posinf, c10::optional<double> neginf) {
|
c10::optional<double> posinf,
|
||||||
|
c10::optional<double> neginf) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_hardtanh(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, const at::Scalar& min_val,
|
compute_shape_hardtanh(const at::Tensor &self, const at::Scalar &min_val,
|
||||||
const at::Scalar& max_val) {
|
const at::Scalar &max_val) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_hardtanh_backward(
|
std::vector<torch::lazy::Shape> compute_shape_hardtanh_backward(
|
||||||
const at::Tensor& grad_output, const at::Tensor& self,
|
const at::Tensor &grad_output, const at::Tensor &self,
|
||||||
const at::Scalar& min_val, const at::Scalar& max_val) {
|
const at::Scalar &min_val, const at::Scalar &max_val) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_where(const at::Tensor& condition,
|
std::vector<torch::lazy::Shape> compute_shape_where(const at::Tensor &condition,
|
||||||
const at::Tensor& self,
|
const at::Tensor &self,
|
||||||
const at::Tensor& other) {
|
const at::Tensor &other) {
|
||||||
// There are cases like -
|
// There are cases like -
|
||||||
// torch.aten.where.self %42, %arg17, %37 : !torch.vtensor<[15,10],i1>,
|
// torch.aten.where.self %42, %arg17, %37 : !torch.vtensor<[15,10],i1>,
|
||||||
// !torch.vtensor<[],f32>, !torch.vtensor<[15,10],f32>.
|
// !torch.vtensor<[],f32>, !torch.vtensor<[15,10],f32>.
|
||||||
|
@ -201,32 +202,32 @@ std::vector<torch::lazy::Shape> compute_shape_where(const at::Tensor& condition,
|
||||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_bucketize(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, const at::Tensor& boundaries, bool out_int32,
|
compute_shape_bucketize(const at::Tensor &self, const at::Tensor &boundaries,
|
||||||
bool right) {
|
bool out_int32, bool right) {
|
||||||
auto dtype = out_int32 ? at::kInt : at::kLong;
|
auto dtype = out_int32 ? at::kInt : at::kLong;
|
||||||
return {Shape(dtype, self.sizes().vec())};
|
return {Shape(dtype, self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_copy(const at::Tensor& self,
|
std::vector<torch::lazy::Shape> compute_shape_copy(const at::Tensor &self,
|
||||||
const at::Tensor& src,
|
const at::Tensor &src,
|
||||||
bool non_blocking) {
|
bool non_blocking) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_floor_divide(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, const at::Tensor& other) {
|
compute_shape_floor_divide(const at::Tensor &self, const at::Tensor &other) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_fmod(const at::Tensor& self,
|
std::vector<torch::lazy::Shape> compute_shape_fmod(const at::Tensor &self,
|
||||||
const at::Scalar& other) {
|
const at::Scalar &other) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
|
std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
|
||||||
const at::Tensor& input, const c10::optional<at::Tensor>& weight,
|
const at::Tensor &input, const c10::optional<at::Tensor> &weight,
|
||||||
const c10::optional<at::Tensor>& bias, int64_t N, int64_t C, int64_t HxW,
|
const c10::optional<at::Tensor> &bias, int64_t N, int64_t C, int64_t HxW,
|
||||||
int64_t group, double eps) {
|
int64_t group, double eps) {
|
||||||
|
|
||||||
TORCH_CHECK(input.sizes().size() >= 2,
|
TORCH_CHECK(input.sizes().size() >= 2,
|
||||||
|
@ -244,9 +245,10 @@ std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
|
||||||
return shapes;
|
return shapes;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_im2col(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, at::IntArrayRef kernel_size,
|
compute_shape_im2col(const at::Tensor &self, at::IntArrayRef kernel_size,
|
||||||
at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) {
|
at::IntArrayRef dilation, at::IntArrayRef padding,
|
||||||
|
at::IntArrayRef stride) {
|
||||||
|
|
||||||
auto self_meta = at::native::empty_strided_meta_symint(
|
auto self_meta = at::native::empty_strided_meta_symint(
|
||||||
self.sym_sizes(), self.sym_strides(),
|
self.sym_sizes(), self.sym_strides(),
|
||||||
|
@ -260,8 +262,8 @@ std::vector<torch::lazy::Shape> compute_shape_im2col(
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
|
std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
|
||||||
const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean,
|
const at::Tensor &grad_out, const at::Tensor &input, const at::Tensor &mean,
|
||||||
const at::Tensor& rstd, const c10::optional<at::Tensor>& weight, int64_t N,
|
const at::Tensor &rstd, const c10::optional<at::Tensor> &weight, int64_t N,
|
||||||
int64_t C, int64_t HxW, int64_t group, ::std::array<bool, 3> output_mask) {
|
int64_t C, int64_t HxW, int64_t group, ::std::array<bool, 3> output_mask) {
|
||||||
|
|
||||||
TORCH_CHECK(input.sizes().size() >= 2,
|
TORCH_CHECK(input.sizes().size() >= 2,
|
||||||
|
@ -280,8 +282,8 @@ std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
|
||||||
|
|
||||||
return shapes;
|
return shapes;
|
||||||
}
|
}
|
||||||
std::vector<torch::lazy::Shape> compute_shape_remainder(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, const at::Scalar& other) {
|
compute_shape_remainder(const at::Tensor &self, const at::Scalar &other) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -313,21 +315,22 @@ compute_shape_reflection_pad2d(const at::Tensor &self,
|
||||||
return {Shape(self.scalar_type(), out_sizes)};
|
return {Shape(self.scalar_type(), out_sizes)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_uniform(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, double from, double to,
|
compute_shape_uniform(const at::Tensor &self, double from, double to,
|
||||||
c10::optional<at::Generator> generator) {
|
c10::optional<at::Generator> generator) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_normal_functional(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, double mean, double std,
|
compute_shape_normal_functional(const at::Tensor &self, double mean, double std,
|
||||||
c10::optional<at::Generator> generator) {
|
c10::optional<at::Generator> generator) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_multinomial(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, int64_t num_samples, bool replacement,
|
compute_shape_multinomial(const at::Tensor &self, int64_t num_samples,
|
||||||
c10::optional<at::Generator> generator) {
|
bool replacement,
|
||||||
|
c10::optional<at::Generator> generator) {
|
||||||
// Input tensor can be either 1D or 2D. The last dim of output
|
// Input tensor can be either 1D or 2D. The last dim of output
|
||||||
// should be 'num_samples'. So the output shape can be either
|
// should be 'num_samples'. So the output shape can be either
|
||||||
// [num_samples] or [m, num_samples].
|
// [num_samples] or [m, num_samples].
|
||||||
|
@ -337,35 +340,38 @@ std::vector<torch::lazy::Shape> compute_shape_multinomial(
|
||||||
return {Shape(at::kLong, ishape)};
|
return {Shape(at::kLong, ishape)};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_eye(
|
std::vector<torch::lazy::Shape>
|
||||||
int64_t n, c10::optional<at::ScalarType> dtype,
|
compute_shape_eye(int64_t n, c10::optional<at::ScalarType> dtype,
|
||||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
c10::optional<at::Layout> layout,
|
||||||
c10::optional<bool> pin_memory) {
|
c10::optional<at::Device> device,
|
||||||
|
c10::optional<bool> pin_memory) {
|
||||||
auto out_meta =
|
auto out_meta =
|
||||||
at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
||||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_eye(
|
std::vector<torch::lazy::Shape>
|
||||||
int64_t n, int64_t m, c10::optional<at::ScalarType> dtype,
|
compute_shape_eye(int64_t n, int64_t m, c10::optional<at::ScalarType> dtype,
|
||||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
c10::optional<at::Layout> layout,
|
||||||
c10::optional<bool> pin_memory) {
|
c10::optional<at::Device> device,
|
||||||
|
c10::optional<bool> pin_memory) {
|
||||||
auto out_meta =
|
auto out_meta =
|
||||||
at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
||||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_arange(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Scalar& end, c10::optional<at::ScalarType> dtype,
|
compute_shape_arange(const at::Scalar &end, c10::optional<at::ScalarType> dtype,
|
||||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
c10::optional<at::Layout> layout,
|
||||||
c10::optional<bool> pin_memory) {
|
c10::optional<at::Device> device,
|
||||||
|
c10::optional<bool> pin_memory) {
|
||||||
auto out_meta =
|
auto out_meta =
|
||||||
at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
||||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_arange(
|
std::vector<torch::lazy::Shape> compute_shape_arange(
|
||||||
const at::Scalar& start, const at::Scalar& end,
|
const at::Scalar &start, const at::Scalar &end,
|
||||||
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||||
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||||
auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta),
|
auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta),
|
||||||
|
@ -374,7 +380,7 @@ std::vector<torch::lazy::Shape> compute_shape_arange(
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_arange(
|
std::vector<torch::lazy::Shape> compute_shape_arange(
|
||||||
const at::Scalar& start, const at::Scalar& end, const at::Scalar& step,
|
const at::Scalar &start, const at::Scalar &end, const at::Scalar &step,
|
||||||
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||||
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||||
auto out_meta = at::arange(start, end, step, dtype, layout,
|
auto out_meta = at::arange(start, end, step, dtype, layout,
|
||||||
|
@ -383,34 +389,37 @@ std::vector<torch::lazy::Shape> compute_shape_arange(
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_full(
|
std::vector<torch::lazy::Shape> compute_shape_full(
|
||||||
at::IntArrayRef size, const at::Scalar& fill_value,
|
at::IntArrayRef size, const at::Scalar &fill_value,
|
||||||
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||||
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||||
return {
|
return {
|
||||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_ones(
|
std::vector<torch::lazy::Shape>
|
||||||
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
compute_shape_ones(at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
||||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
c10::optional<at::Layout> layout,
|
||||||
c10::optional<bool> pin_memory) {
|
c10::optional<at::Device> device,
|
||||||
|
c10::optional<bool> pin_memory) {
|
||||||
return {
|
return {
|
||||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_zeros(
|
std::vector<torch::lazy::Shape>
|
||||||
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
compute_shape_zeros(at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
||||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
c10::optional<at::Layout> layout,
|
||||||
c10::optional<bool> pin_memory) {
|
c10::optional<at::Device> device,
|
||||||
|
c10::optional<bool> pin_memory) {
|
||||||
return {
|
return {
|
||||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_empty(
|
std::vector<torch::lazy::Shape>
|
||||||
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
compute_shape_empty(at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
||||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
c10::optional<at::Layout> layout,
|
||||||
c10::optional<bool> pin_memory,
|
c10::optional<at::Device> device,
|
||||||
c10::optional<at::MemoryFormat> memory_format) {
|
c10::optional<bool> pin_memory,
|
||||||
|
c10::optional<at::MemoryFormat> memory_format) {
|
||||||
return {
|
return {
|
||||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||||
}
|
}
|
||||||
|
@ -423,20 +432,21 @@ std::vector<torch::lazy::Shape> compute_shape_empty_strided(
|
||||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self,
|
std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor &self,
|
||||||
const at::Scalar& value) {
|
const at::Scalar &value) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self,
|
std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor &self,
|
||||||
const at::Tensor& value) {
|
const at::Tensor &value) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_randn(
|
std::vector<torch::lazy::Shape>
|
||||||
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
compute_shape_randn(at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
|
||||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
c10::optional<at::Layout> layout,
|
||||||
c10::optional<bool> pin_memory) {
|
c10::optional<at::Device> device,
|
||||||
|
c10::optional<bool> pin_memory) {
|
||||||
return {
|
return {
|
||||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||||
}
|
}
|
||||||
|
@ -457,36 +467,39 @@ std::vector<torch::lazy::Shape> compute_shape_randint(
|
||||||
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_resize(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor & self, at::IntArrayRef size,
|
compute_shape_resize(const at::Tensor &self, at::IntArrayRef size,
|
||||||
c10::optional<at::MemoryFormat> memory_format) {
|
c10::optional<at::MemoryFormat> memory_format) {
|
||||||
return {Shape(self.scalar_type(), size.vec())};
|
return {Shape(self.scalar_type(), size.vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_bernoulli(
|
std::vector<torch::lazy::Shape>
|
||||||
const at::Tensor& self, const at::Tensor &p,
|
compute_shape_bernoulli(const at::Tensor &self, const at::Tensor &p,
|
||||||
c10::optional<at::Generator> generator) {
|
c10::optional<at::Generator> generator) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
|
std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
|
||||||
const at::Scalar & s, c10::optional<at::ScalarType> dtype,
|
const at::Scalar &s, c10::optional<at::ScalarType> dtype,
|
||||||
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
|
||||||
c10::optional<bool> pin_memory) {
|
c10::optional<bool> pin_memory) {
|
||||||
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
|
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_roll(
|
std::vector<torch::lazy::Shape> compute_shape_roll(const at::Tensor &self,
|
||||||
const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) {
|
at::IntArrayRef shifts,
|
||||||
|
at::IntArrayRef dims) {
|
||||||
return {Shape(self.scalar_type(), self.sizes().vec())};
|
return {Shape(self.scalar_type(), self.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<torch::lazy::Shape> compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
std::vector<torch::lazy::Shape> compute_shape_linspace(
|
||||||
auto out_meta =
|
const at::Scalar &start, const at::Scalar &end, int64_t steps,
|
||||||
at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory);
|
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
|
||||||
|
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
|
||||||
|
auto out_meta = at::linspace(start, end, steps, dtype, layout,
|
||||||
|
c10::Device(c10::kMeta), pin_memory);
|
||||||
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace lazy
|
||||||
} // namespace lazy
|
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
@ -14,16 +14,16 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
at::Tensor CreateFunctionalizedAtenFromLtcTensor(
|
at::Tensor
|
||||||
const LazyTensorPtr& ltc_tensor) {
|
CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr &ltc_tensor) {
|
||||||
at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor);
|
at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor);
|
||||||
if (!c10::impl::tls_is_dispatch_key_excluded(
|
if (!c10::impl::tls_is_dispatch_key_excluded(
|
||||||
c10::DispatchKey::Functionalize) &&
|
c10::DispatchKey::Functionalize) &&
|
||||||
!at::functionalization::impl::isFunctionalTensor(tensor)) {
|
!at::functionalization::impl::isFunctionalTensor(tensor)) {
|
||||||
return at::functionalization::impl::to_functional_tensor(tensor);
|
return at::functionalization::impl::to_functional_tensor(tensor);
|
||||||
}
|
}
|
||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
@ -18,7 +18,8 @@ namespace lazy {
|
||||||
// should have explicit tensor functinoalization. Otherwise we can get
|
// should have explicit tensor functinoalization. Otherwise we can get
|
||||||
// unfanctionalized primitives or in the worst case if we apply inplace
|
// unfanctionalized primitives or in the worst case if we apply inplace
|
||||||
// operations to unfunctionalized tensor it won't be captured in LTC graph.
|
// operations to unfunctionalized tensor it won't be captured in LTC graph.
|
||||||
TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor);
|
TORCH_API at::Tensor
|
||||||
|
CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr &ltc_tensor);
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
@ -21,8 +21,8 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
#define UNIMPLEMENTED_FUNCTION_ERROR() \
|
#define UNIMPLEMENTED_FUNCTION_ERROR() \
|
||||||
UNIMPLEMENTED_ERROR( \
|
UNIMPLEMENTED_ERROR("\n\t" << __FILE__ << ":" << __LINE__ << " " \
|
||||||
"\n\t" << __FILE__ << ":" << __LINE__ << " " << __PRETTY_FUNCTION__)
|
<< __PRETTY_FUNCTION__)
|
||||||
|
|
||||||
#define UNSUPPORTED_ERROR(msg) \
|
#define UNSUPPORTED_ERROR(msg) \
|
||||||
{ \
|
{ \
|
||||||
|
|
|
@ -7,9 +7,9 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
void ConvertScalarImplicit(std::shared_ptr<Graph>& graph) {
|
void ConvertScalarImplicit(std::shared_ptr<Graph> &graph) {
|
||||||
DepthFirstGraphNodeIterator it(graph);
|
DepthFirstGraphNodeIterator it(graph);
|
||||||
for (auto* node = it.next(); node != nullptr; node = it.next()) {
|
for (auto *node = it.next(); node != nullptr; node = it.next()) {
|
||||||
if (node->kind() != c10::aten::ScalarImplicit) {
|
if (node->kind() != c10::aten::ScalarImplicit) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -27,15 +27,13 @@ void ConvertScalarImplicit(std::shared_ptr<Graph>& graph) {
|
||||||
node_type = c10::aten::FloatImplicit;
|
node_type = c10::aten::FloatImplicit;
|
||||||
output_type = FloatType::get();
|
output_type = FloatType::get();
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error("Expected isIntegralType or isFloatingType");
|
||||||
"Expected isIntegralType or isFloatingType");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Value * output = graph
|
Value *output = graph->create(node_type, {input})
|
||||||
->create(node_type, {input})
|
->insertBefore(node)
|
||||||
->insertBefore(node)
|
->output()
|
||||||
->output()
|
->setType(output_type);
|
||||||
->setType(output_type);
|
|
||||||
node->output()->replaceAllUsesWith(output);
|
node->output()->replaceAllUsesWith(output);
|
||||||
node->destroy();
|
node->destroy();
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ namespace torch {
|
||||||
namespace jit {
|
namespace jit {
|
||||||
|
|
||||||
// Convert ScalarImplicit to IntImplicit or FloatImplicit.
|
// Convert ScalarImplicit to IntImplicit or FloatImplicit.
|
||||||
TORCH_API void ConvertScalarImplicit(std::shared_ptr<Graph>& graph);
|
TORCH_API void ConvertScalarImplicit(std::shared_ptr<Graph> &graph);
|
||||||
|
|
||||||
} // namespace jit
|
} // namespace jit
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
@ -1,49 +1,49 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::ostream& string_join(std::ostream& out, const std::vector<T>& v, const std::string& delimiter) {
|
std::ostream &string_join(std::ostream &out, const std::vector<T> &v,
|
||||||
size_t i = 0;
|
const std::string &delimiter) {
|
||||||
for (const T& e : v) {
|
size_t i = 0;
|
||||||
if ((i++) > 0) { out << delimiter; }
|
for (const T &e : v) {
|
||||||
out << e;
|
if ((i++) > 0) {
|
||||||
|
out << delimiter;
|
||||||
}
|
}
|
||||||
return out;
|
out << e;
|
||||||
|
}
|
||||||
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::string string_join(const std::vector<T>& v, const std::string& delimiter) {
|
std::string string_join(const std::vector<T> &v, const std::string &delimiter) {
|
||||||
std::ostringstream joined;
|
std::ostringstream joined;
|
||||||
string_join(joined, v, delimiter);
|
string_join(joined, v, delimiter);
|
||||||
return joined.str();
|
return joined.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::vector<std::string> string_split(
|
inline std::vector<std::string> string_split(const std::string &str,
|
||||||
const std::string& str,
|
const std::string &sep) {
|
||||||
const std::string& sep
|
std::vector<std::string> tokens;
|
||||||
) {
|
std::size_t pos1 = str.find_first_not_of(sep);
|
||||||
std::vector<std::string> tokens;
|
while (pos1 != std::string::npos) {
|
||||||
std::size_t pos1 = str.find_first_not_of(sep);
|
std::size_t pos2 = str.find_first_of(sep, pos1);
|
||||||
while (pos1 != std::string::npos) {
|
if (pos2 == std::string::npos) {
|
||||||
std::size_t pos2 = str.find_first_of(sep, pos1);
|
tokens.push_back(str.substr(pos1));
|
||||||
if (pos2 == std::string::npos) {
|
pos1 = pos2;
|
||||||
tokens.push_back(str.substr(pos1));
|
} else {
|
||||||
pos1 = pos2;
|
tokens.push_back(str.substr(pos1, pos2 - pos1));
|
||||||
} else {
|
pos1 = str.find_first_not_of(sep, pos2 + 1);
|
||||||
tokens.push_back(str.substr(pos1, pos2 - pos1));
|
|
||||||
pos1 = str.find_first_not_of(sep, pos2 + 1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return tokens;
|
}
|
||||||
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Returns true if str starts with prefix
|
* Returns true if str starts with prefix
|
||||||
*/
|
*/
|
||||||
inline bool startswith(const std::string& str, const std::string& prefix) {
|
inline bool startswith(const std::string &str, const std::string &prefix) {
|
||||||
return str.rfind(prefix, 0) == 0;
|
return str.rfind(prefix, 0) == 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,24 +6,25 @@
|
||||||
namespace sys_util {
|
namespace sys_util {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T GetEnv(const std::string& name, const T& default_value = T(0)) {
|
static T GetEnv(const std::string &name, const T &default_value = T(0)) {
|
||||||
const char* env = std::getenv(name.c_str());
|
const char *env = std::getenv(name.c_str());
|
||||||
if (!env) {
|
if (!env) {
|
||||||
return default_value;
|
return default_value;
|
||||||
}
|
}
|
||||||
return T(std::atoi(env));
|
return T(std::atoi(env));
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string GetEnvString(const std::string& name, const std::string& default_value) {
|
static std::string GetEnvString(const std::string &name,
|
||||||
const char* env = std::getenv(name.c_str());
|
const std::string &default_value) {
|
||||||
|
const char *env = std::getenv(name.c_str());
|
||||||
if (!env) {
|
if (!env) {
|
||||||
return default_value;
|
return default_value;
|
||||||
}
|
}
|
||||||
return std::string(env);
|
return std::string(env);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool GetEnvBool(const char* name, bool defval) {
|
static bool GetEnvBool(const char *name, bool defval) {
|
||||||
const char* env = std::getenv(name);
|
const char *env = std::getenv(name);
|
||||||
if (env == nullptr) {
|
if (env == nullptr) {
|
||||||
return defval;
|
return defval;
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,84 +3,90 @@
|
||||||
#include "../generated/LazyIr.h"
|
#include "../generated/LazyIr.h"
|
||||||
#include "../mlir_node.h"
|
#include "../mlir_node.h"
|
||||||
|
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
bool is_detach_copy(const torch::lazy::Node* node) {
|
bool is_detach_copy(const torch::lazy::Node *node) {
|
||||||
return node && node->op() == torch::lazy::DetachCopy::ClassOpKind();
|
return node && node->op() == torch::lazy::DetachCopy::ClassOpKind();
|
||||||
}
|
}
|
||||||
bool is_detach_copy(const torch::lazy::Value& value) {
|
bool is_detach_copy(const torch::lazy::Value &value) {
|
||||||
return is_detach_copy(value.node.get());
|
return is_detach_copy(value.node.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) {
|
torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *node) {
|
||||||
if (!node) { return nullptr; }
|
if (!node) {
|
||||||
|
|
||||||
torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<torch::lazy::TorchMlirNode*>(node);
|
|
||||||
while(mlir_node && is_detach_copy(mlir_node)) {
|
|
||||||
mlir_node = mlir_node->mlir_node(0);
|
|
||||||
}
|
|
||||||
if (!mlir_node) {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
return mlir_node;
|
|
||||||
}
|
|
||||||
|
|
||||||
const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) {
|
|
||||||
if (!node) { return nullptr; }
|
|
||||||
|
|
||||||
const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<const torch::lazy::TorchMlirNode*>(node);
|
|
||||||
while(mlir_node && is_detach_copy(mlir_node)) {
|
|
||||||
mlir_node = mlir_node->mlir_node(0);
|
|
||||||
}
|
|
||||||
if (!mlir_node) {
|
|
||||||
return node;
|
|
||||||
}
|
|
||||||
return mlir_node;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) {
|
|
||||||
if (!node) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
node = extract_non_detach_copy_node(node);
|
|
||||||
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
|
||||||
return dynamic_cast<torch::lazy::DeviceData*>(node);
|
|
||||||
}
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node* node) {
|
|
||||||
if (!node) {
|
torch::lazy::TorchMlirNode *mlir_node =
|
||||||
return nullptr;
|
dynamic_cast<torch::lazy::TorchMlirNode *>(node);
|
||||||
}
|
while (mlir_node && is_detach_copy(mlir_node)) {
|
||||||
node = extract_non_detach_copy_node(node);
|
mlir_node = mlir_node->mlir_node(0);
|
||||||
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
}
|
||||||
return dynamic_cast<const torch::lazy::DeviceData*>(node);
|
if (!mlir_node) {
|
||||||
}
|
return node;
|
||||||
return nullptr;
|
}
|
||||||
}
|
return mlir_node;
|
||||||
torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) {
|
|
||||||
if (!value) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return device_data_cast(value.node.get());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::lazy::DeviceData* device_data_cast(
|
const torch::lazy::Node *
|
||||||
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device
|
extract_non_detach_copy_node(const torch::lazy::Node *node) {
|
||||||
) {
|
if (!node) {
|
||||||
if (!device) {
|
|
||||||
device = torch::lazy::GetBackendDevice(tensor);
|
|
||||||
}
|
|
||||||
TORCH_CHECK(device);
|
|
||||||
torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device);
|
|
||||||
if (lazy_tensor) {
|
|
||||||
return device_data_cast(lazy_tensor->GetIrValue());
|
|
||||||
}
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
const torch::lazy::TorchMlirNode *mlir_node =
|
||||||
|
dynamic_cast<const torch::lazy::TorchMlirNode *>(node);
|
||||||
|
while (mlir_node && is_detach_copy(mlir_node)) {
|
||||||
|
mlir_node = mlir_node->mlir_node(0);
|
||||||
|
}
|
||||||
|
if (!mlir_node) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
return mlir_node;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace lazy
|
torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *node) {
|
||||||
} // namespace torch
|
if (!node) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
node = extract_non_detach_copy_node(node);
|
||||||
|
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
||||||
|
return dynamic_cast<torch::lazy::DeviceData *>(node);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
const torch::lazy::DeviceData *device_data_cast(const torch::lazy::Node *node) {
|
||||||
|
if (!node) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
node = extract_non_detach_copy_node(node);
|
||||||
|
if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) {
|
||||||
|
return dynamic_cast<const torch::lazy::DeviceData *>(node);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
torch::lazy::DeviceData *device_data_cast(const torch::lazy::Value &value) {
|
||||||
|
if (!value) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return device_data_cast(value.node.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
torch::lazy::DeviceData *
|
||||||
|
device_data_cast(const at::Tensor &tensor,
|
||||||
|
c10::optional<torch::lazy::BackendDevice> device) {
|
||||||
|
if (!device) {
|
||||||
|
device = torch::lazy::GetBackendDevice(tensor);
|
||||||
|
}
|
||||||
|
TORCH_CHECK(device);
|
||||||
|
torch::lazy::LazyTensorPtr lazy_tensor =
|
||||||
|
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device);
|
||||||
|
if (lazy_tensor) {
|
||||||
|
return device_data_cast(lazy_tensor->GetIrValue());
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace lazy
|
||||||
|
} // namespace torch
|
||||||
|
|
|
@ -8,18 +8,21 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace lazy {
|
namespace lazy {
|
||||||
|
|
||||||
TORCH_API bool is_detach_copy(const torch::lazy::Node*);
|
TORCH_API bool is_detach_copy(const torch::lazy::Node *);
|
||||||
TORCH_API bool is_detach_copy(const torch::lazy::Value&);
|
TORCH_API bool is_detach_copy(const torch::lazy::Value &);
|
||||||
|
|
||||||
TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*);
|
TORCH_API torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *);
|
||||||
TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*);
|
TORCH_API const torch::lazy::Node *
|
||||||
|
extract_non_detach_copy_node(const torch::lazy::Node *);
|
||||||
|
|
||||||
TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*);
|
TORCH_API torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *);
|
||||||
TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*);
|
TORCH_API const torch::lazy::DeviceData *
|
||||||
TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value);
|
device_data_cast(const torch::lazy::Node *);
|
||||||
TORCH_API torch::lazy::DeviceData* device_data_cast(
|
TORCH_API torch::lazy::DeviceData *
|
||||||
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt
|
device_data_cast(const torch::lazy::Value &value);
|
||||||
);
|
TORCH_API torch::lazy::DeviceData *device_data_cast(
|
||||||
|
const at::Tensor &tensor,
|
||||||
|
c10::optional<torch::lazy::BackendDevice> device = c10::nullopt);
|
||||||
|
|
||||||
} // namespace lazy
|
} // namespace lazy
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
|
|
@ -30,7 +30,7 @@ namespace lazy {
|
||||||
|
|
||||||
/// Returns true if a string begins with another.
|
/// Returns true if a string begins with another.
|
||||||
inline bool beginswith(const std::string& s, const std::string& t) {
|
inline bool beginswith(const std::string& s, const std::string& t) {
|
||||||
return s.size() >= t.size() && s.compare(0, t.size(), t) == 0;
|
return s.size() >= t.size() && s.compare(0, t.size(), t) == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
|
struct ReferenceLazyBackendDeviceType : public BackendDeviceType {
|
||||||
|
@ -73,10 +73,8 @@ public:
|
||||||
// Vendor backend specific lowering can be exec here before returning.
|
// Vendor backend specific lowering can be exec here before returning.
|
||||||
for (const auto& instance : instances) {
|
for (const auto& instance : instances) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
instance->in_mark_step,
|
instance->in_mark_step, "Compile outside of mark step:\n",
|
||||||
"Compile outside of mark step:\n",
|
GetComputationBackendText(instance));
|
||||||
GetComputationBackendText(instance)
|
|
||||||
);
|
|
||||||
// Store computation instance for external access after compilation.
|
// Store computation instance for external access after compilation.
|
||||||
GetLatestComputation() = instance;
|
GetLatestComputation() = instance;
|
||||||
}
|
}
|
||||||
|
@ -114,16 +112,17 @@ public:
|
||||||
// Convert any lazy devices to cpu devices to ensure
|
// Convert any lazy devices to cpu devices to ensure
|
||||||
// that the values are actually computed
|
// that the values are actually computed
|
||||||
if (node->outputs().size() == 1 &&
|
if (node->outputs().size() == 1 &&
|
||||||
node->output()->type()->kind() ==
|
node->output()->type()->kind() == c10::TypeKind::DeviceObjType) {
|
||||||
c10::TypeKind::DeviceObjType) {
|
auto value_sym = torch::jit::Symbol::attr("value");
|
||||||
auto value_sym = torch::jit::Symbol::attr("value");
|
TORCH_CHECK(
|
||||||
TORCH_CHECK(node->hasAttribute(value_sym),
|
node->hasAttribute(value_sym),
|
||||||
"Expected node to have 'value' attribute.");
|
"Expected node to have 'value' attribute.");
|
||||||
TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s,
|
TORCH_CHECK(
|
||||||
"Expected 'value' attribute to be a string.");
|
node->kindOf(value_sym) == torch::jit::AttributeKind::s,
|
||||||
if (beginswith(node->s(value_sym), "lazy")) {
|
"Expected 'value' attribute to be a string.");
|
||||||
node->s_(value_sym, "cpu");
|
if (beginswith(node->s(value_sym), "lazy")) {
|
||||||
}
|
node->s_(value_sym, "cpu");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,7 +131,8 @@ public:
|
||||||
for (const auto& argument : arguments) {
|
for (const auto& argument : arguments) {
|
||||||
const auto mlir_data =
|
const auto mlir_data =
|
||||||
std::static_pointer_cast<TorchMlirBackendData>(argument);
|
std::static_pointer_cast<TorchMlirBackendData>(argument);
|
||||||
auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
|
auto* info =
|
||||||
|
dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
|
||||||
TORCH_CHECK(info);
|
TORCH_CHECK(info);
|
||||||
if (info->scalar.has_value()) {
|
if (info->scalar.has_value()) {
|
||||||
stack.emplace_back(info->scalar.value());
|
stack.emplace_back(info->scalar.value());
|
||||||
|
|
|
@ -8,8 +8,8 @@
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "torch/csrc/jit/python/pybind.h"
|
#include "torch/csrc/jit/python/pybind.h"
|
||||||
#include "torch/csrc/lazy/core/config.h"
|
|
||||||
#include "torch/csrc/lazy/backend/backend_interface.h"
|
#include "torch/csrc/lazy/backend/backend_interface.h"
|
||||||
|
#include "torch/csrc/lazy/core/config.h"
|
||||||
|
|
||||||
#include <base_lazy_backend/mlir_lowering_context.h>
|
#include <base_lazy_backend/mlir_lowering_context.h>
|
||||||
#include <base_lazy_backend/utils/string_utils.h>
|
#include <base_lazy_backend/utils/string_utils.h>
|
||||||
|
@ -56,8 +56,8 @@ void Initialize() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ir_debug) {
|
if (ir_debug) {
|
||||||
FLAGS_torch_lazy_ir_debug = true;
|
FLAGS_torch_lazy_ir_debug = true;
|
||||||
std::cout << "Enabled lazy tensor IR debugging." << std::endl;
|
std::cout << "Enabled lazy tensor IR debugging." << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,15 +82,17 @@ PYBIND11_MODULE(_REFERENCE_LAZY_BACKEND, m) {
|
||||||
torch::lazy::GetLatestComputation().get());
|
torch::lazy::GetLatestComputation().get());
|
||||||
return py::cast(computation);
|
return py::cast(computation);
|
||||||
});
|
});
|
||||||
m.def("set_parameter_name",
|
m.def(
|
||||||
[](const at::Tensor& tensor, const std::string& name) -> bool {
|
"set_parameter_name",
|
||||||
torch::lazy::DeviceData* ir_node = torch::lazy::device_data_cast(tensor);
|
[](const at::Tensor& tensor, const std::string& name) -> bool {
|
||||||
if (ir_node) {
|
torch::lazy::DeviceData* ir_node =
|
||||||
ir_node->SetName(name);
|
torch::lazy::device_data_cast(tensor);
|
||||||
return true;
|
if (ir_node) {
|
||||||
}
|
ir_node->SetName(name);
|
||||||
return false;
|
return true;
|
||||||
});
|
}
|
||||||
|
return false;
|
||||||
|
});
|
||||||
m.def("_initialize", []() {
|
m.def("_initialize", []() {
|
||||||
NoGilSection gil;
|
NoGilSection gil;
|
||||||
Initialize();
|
Initialize();
|
||||||