Clang format refresh (#2812)

After noticing a number of commits with unrelated formatting changes,
I think something was changed with clang-format at one point and we're
seeing a number of unrelated changes. Doing a refresh can help avoid
this.

The changes made here came from
```
find lib -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
find include -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
find projects -iname *.h -o -iname *.cpp  | xargs clang-format -i --style=llvm
```
pull/2823/head
Quinn Dawkins 2024-01-29 12:59:33 -05:00 committed by GitHub
parent d3fd754b93
commit 494089d53d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
81 changed files with 1972 additions and 1815 deletions

View File

@ -10,9 +10,9 @@
#ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
#define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

View File

@ -97,7 +97,8 @@ struct OpBinder {
return success();
}
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) {
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx,
int64_t idx) {
if (idx >= op->getNumResults())
return failure();
auto t = toValidTensorType(op->getResult(idx).getType());

View File

@ -37,8 +37,8 @@ TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op,
return CreateOpAndInfer<TosaOpT>(rewriter, op->getLoc(), outType, lhs, rhs);
}
// This specialization is for Div op. Unlike other binary ops, it doesn't support
// floating type.
// This specialization is for Div op. Unlike other binary ops, it doesn't
// support floating type.
template <>
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
Operation *op, TensorType outType,
@ -53,9 +53,8 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
// Lowers torch.aten.Gather operators to a sequence of TOSA ops.
// Revised from
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
Operation *op, Type out_type,
Value params_value,
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
Type out_type, Value params_value,
Value indices_value);
std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
@ -63,7 +62,6 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
Value paramsValue, Value indicesValue,
Value fillValues);
// Lowers ReduceAll to a sequence of TOSA ops.
std::optional<Value>
convertReduceAllOp(PatternRewriter &rewriter, Operation *op,

View File

@ -36,8 +36,7 @@ class HasValueSemantics
// This is a weaker form of HasValueSemantics, since that trait also requires no
// aliasing. That is, HasValueSemantics implies this trait.
template <typename ConcreteType>
class ReadOnly
: public ::mlir::OpTrait::TraitBase<ConcreteType, ReadOnly> {};
class ReadOnly : public ::mlir::OpTrait::TraitBase<ConcreteType, ReadOnly> {};
// If a Torch op has this trait, it means that the op is a "trailing underscore"
// op variant that performs an in-place operation on its first argument. These
@ -62,7 +61,8 @@ class AllowsTypeRefinement
// by the IValue importer.
template <typename ConcreteType>
class AllowedInModuleInitializer
: public ::mlir::OpTrait::TraitBase<ConcreteType, AllowedInModuleInitializer> {};
: public ::mlir::OpTrait::TraitBase<ConcreteType,
AllowedInModuleInitializer> {};
} // namespace OpTrait
} // namespace Torch

View File

@ -61,7 +61,8 @@ struct TorchLoweringPipelineOptions
Option<std::string> extraLibrary{
*this, "extra-library",
llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")};
llvm::cl::desc("Filename of MLIR module for splicing into the abstract "
"interpretation library.")};
};
/// Creates a pipeline that lowers the object graph IR that is produced by
@ -125,8 +126,7 @@ createSimplifyDtypeCalculationsPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createDropAbstractInterpCalculationsPass();
std::unique_ptr<OperationPass<ModuleOp>>
createEraseModuleInitializerPass();
std::unique_ptr<OperationPass<ModuleOp>> createEraseModuleInitializerPass();
std::unique_ptr<OperationPass<ModuleOp>>
createLowerToBackendContractPass(int maxIterations, bool decompose,

View File

@ -140,12 +140,7 @@ enum Reduction { None, Mean, Sum, END };
// Source:
// https://github.com/pytorch/pytorch/blob/master/c10/core/MemoryFormat.h
//===----------------------------------------------------------------------===//
enum MemoryFormat {
Contiguous,
Preserve,
ChannelsLast,
ChannelsLast3d
};
enum MemoryFormat { Contiguous, Preserve, ChannelsLast, ChannelsLast3d };
//===----------------------------------------------------------------------===//
// Possible values for `layout` argument in PyTorch ops that support it.

View File

@ -121,8 +121,7 @@ LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter,
// Helper to create a tensor filled with the given scalar. Scalar would be
// converted the to the element type of the given tensor type.
Value createInitTensor(PatternRewriter &rewriter, Location loc,
BaseTensorType resultType, Value scalar,
Value sizeList);
BaseTensorType resultType, Value scalar, Value sizeList);
// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar`
// would be converted to the element type of the given `inputType`.

View File

@ -9,7 +9,8 @@
#include "torch-mlir-c/Dialects.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "mlir/CAPI/Registration.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch, mlir::torch::Torch::TorchDialect)
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch,
mlir::torch::Torch::TorchDialect)

View File

@ -30,6 +30,4 @@ namespace {
#include "torch-mlir/Conversion/Passes.h.inc"
} // end namespace
void mlir::torch::registerConversionPasses() {
::registerPasses();
}
void mlir::torch::registerConversionPasses() { ::registerPasses(); }

View File

@ -82,7 +82,8 @@ public:
// temp = multiplier * currentSeed + incrementStep
Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
globalVar = rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
globalVar =
rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
rewriter.create<ml_program::GlobalStoreOp>(
loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
globalVar);

View File

@ -29,7 +29,8 @@ using namespace mlir::torch::onnx_c;
// thing here, so we simplify.
void mlir::torch::onnx_c::populateDefaultDomainGtoP(
OnnxCustomOpConversionPattern &patterns) {
patterns.onOp("HardSigmoid", 6,
patterns.onOp(
"HardSigmoid", 6,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value tensorOperand;
@ -40,7 +41,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.tensorResultType(resultType))
return failure();
// HardSigmoid computes the following expression: max(0, min(1, alpha * x + beta))
// HardSigmoid computes the following expression:
// max(0, min(1, alpha * x + beta))
Value constAlpha = rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
rewriter.getF64FloatAttr(alpha));
@ -51,7 +53,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
// Expression: alpha * x + beta
Value alpha_x_plus_beta = rewriter.create<Torch::AtenAddScalarOp>(
binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha);
binder.getLoc(), resultType, tensorOperand, constBeta,
/*alpha=*/constAlpha);
// Expression: min(1, alpha * x + beta)
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
@ -332,13 +335,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, lhs, rhs);
return success();
});
patterns.onOp("Max", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
patterns.onOp(
"Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType) ||
operands.size() == 0) {
binder.tensorResultType(resultType) || operands.size() == 0) {
return failure();
}
Value result = operands[0];
@ -349,13 +351,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.replaceOp(binder.op, result.getDefiningOp());
return success();
});
patterns.onOp("Min", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
patterns.onOp(
"Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType) ||
operands.size() == 0) {
binder.tensorResultType(resultType) || operands.size() == 0) {
return failure();
}
Value result = operands[0];
@ -363,8 +364,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
result = rewriter.create<Torch::AtenMinimumOp>(
binder.getLoc(), resultType, result, operands[i]);
}
rewriter.replaceOp(
binder.op, result.getDefiningOp());
rewriter.replaceOp(binder.op, result.getDefiningOp());
return success();
});
patterns.onOp("Neg", 1,
@ -693,7 +693,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
cstStrides);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value cstCeilMode = cstFalse;
Value cstCountIncludePad = cstFalse;
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

View File

@ -42,7 +42,8 @@ Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
OnnxCustomOpConversionPattern &patterns) {
patterns.onOp("QuantizeLinear", 1,
patterns.onOp(
"QuantizeLinear", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
@ -56,11 +57,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
if (!scaleTy || !scaleTy.hasSizes())
return rewriter.notifyMatchFailure(binder.op,
"requires known rank");
return rewriter.notifyMatchFailure(binder.op, "requires known rank");
if (!resultType.hasDtype())
return rewriter.notifyMatchFailure(
binder.op, "requires known result dtype");
return rewriter.notifyMatchFailure(binder.op,
"requires known result dtype");
if (scaleTy.getSizes().size() == 0) {
Type qTy = resultType.getDtype();
@ -72,21 +72,28 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
} else if (qTy.isSignedInteger(32)) {
qTy = rewriter.getType<Torch::QInt32Type>();
} else {
return rewriter.notifyMatchFailure(binder.op, "unsupported result dtype");
return rewriter.notifyMatchFailure(binder.op,
"unsupported result dtype");
}
auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(resultType.getOptionalSizes(), qTy);
auto qTensorTy = rewriter.getType<Torch::ValueTensorType>(
resultType.getOptionalSizes(), qTy);
auto torchqTy = Torch::getScalarTypeForType(qTy);
Value tyConst = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast<int64_t>(torchqTy)));
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
static_cast<int64_t>(torchqTy)));
scale = rewriter.create<Torch::AtenItemOp>(binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
zeropoint = rewriter.create<Torch::AtenItemOp>(binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
scale = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), scale);
zeropoint = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), zeropoint);
auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst);
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(binder.op, resultType, quantize);
auto quantize = rewriter.create<Torch::AtenQuantizePerTensorOp>(
binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst);
rewriter.replaceOpWithNewOp<Torch::AtenIntReprOp>(
binder.op, resultType, quantize);
return success();
}

View File

@ -43,7 +43,8 @@ public:
LogicalResult
matchAndRewrite(AtenDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto rank = rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf());
auto rank =
rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf());
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
op, getTypeConverter()->convertType(op.getType()), rank);
return success();
@ -74,7 +75,8 @@ public:
matchAndRewrite(AtenOp op,
typename OpConversionPattern<AtenOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(), adaptor.getB());
rewriter.template replaceOpWithNewOp<BinOp>(op, adaptor.getA(),
adaptor.getB());
return success();
}
};
@ -112,10 +114,10 @@ public:
typename OpConversionPattern<AtenDivIntOp>::OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value a =
convertScalarToDtype(rewriter, loc, adaptor.getA(), rewriter.getF64Type());
Value b =
convertScalarToDtype(rewriter, loc, adaptor.getB(), rewriter.getF64Type());
Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(),
rewriter.getF64Type());
Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(),
rewriter.getF64Type());
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, a, b);
return success();
}
@ -178,13 +180,14 @@ public:
auto shapedType =
RankedTensorType::get(type.getShape(), builtinTensorElemTy);
auto rawData = elements.getRawData();
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
shapedType, rawData);
DenseElementsAttr newAttr =
DenseElementsAttr::getFromRawBuffer(shapedType, rawData);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
return success();
}
}
if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
if (auto elements =
op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
Type builtinTensorElemTy =
@ -360,7 +363,8 @@ public:
// -----------------------------------------------------------------------------
namespace {
class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith> {
class ConvertTorchToArith
: public ConvertTorchToArithBase<ConvertTorchToArith> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();

View File

@ -110,22 +110,32 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
// Example:
// input = tensor([[[0., 1., 2., 3.],
// [4., 5., 6., 7.]]])
// torch.ops.aten.reflection_pad1d(input, (3,1)) ; padding_left = 3, padding_right = 1
// tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
// torch.ops.aten.reflection_pad1d(input, (3,1));
// padding_left = 3,
// padding_right = 1
// output = tensor([[[3., 2., 1., 0., 1., 2., 3., 2.],
// [7., 6., 5., 4., 5., 6., 7., 6.]]])
// Checks: 1) Each of padding_left and padding_right must be non-negative less than size of last dimension
// Implementation: a) Construct a result tensor of shape of input tensor except for the last dimension.
// The last dimension of the result tensor should be last dimension of input tensor +
// left padding size + right padding size. INitialize result tensor to all zeros
// b) Setup affine map to take slice from input tensor of size left padding starting from
// second column onwards as first column is reflection boundary
// Checks: 1) Each of padding_left and padding_right must be non-negative and
// less than the size of the last dimension.
// Implementation: a) Construct a result tensor of
// shape of input tensor except for the last dimension.
// The last dimension of the result tensor should be last
// dimension of input tensor + left padding size + right
// padding size. Initialize result tensor to all zeros
// b) Setup affine map to take slice from input tensor of size
// left padding starting from
// second column onwards as first column is reflection
// boundary
// c) Reflect the affine map to have resultant slice reflected
// d) Take the slice and write from begining in result tensor
// e) write the original tensor next into result tensor
// f) Setup affine map to take slice from input tensor of right padding size ending
// at second last column as last column is reflection boundary for right padding
// f) Setup affine map to take slice from input tensor of right
// padding size ending
// at second last column as last column is reflection
// boundary for right padding
// g) Reflect the affine map to have resultant slice reflected
// h) Take the slice and write from left padding size + orignal tensor last dim size
// h) Take the slice and write from left padding size + orignal
// tensor last dim size
// into result tensor
// Uses the ideas/code used for AtenReflectionPad2dOp
namespace {
@ -165,43 +175,56 @@ public:
Value zero = getConstant(rewriter, loc, 0, indexType);
Value one = getConstant(rewriter, loc, 1, indexType);
auto inputType = llvm::cast<RankedTensorType>(input.getType());
auto outputType = llvm::cast<RankedTensorType>(getTypeConverter()->convertType(op->getResult(0).getType()));
auto outputType = llvm::cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
unsigned numDims = inputType.getRank();
assert(numDims >= 2 && "Not enough input dimensions");
int64_t lastDim = numDims - 1;
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, inputShape[2] will give 4
Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2,
// inputShape[2] will give 4
Value tileWidth[3], extractOffset[3], insertOffset[3];
tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType);
tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType);
tileWidth[PAD_LEFT] =
getConstant(rewriter, loc, padInts[PAD_LEFT], indexType);
tileWidth[PAD_RIGHT] =
getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType);
tileWidth[PAD_CENTER] = lastDimSize;
extractOffset[PAD_LEFT] = one;
// for (1,2,4) input, padding (3,1) lastDimSize=4, 4 - 1 - 1 = 2 [3,5, 6,7], so start offset to 6, which is right
// lasDimSize - (tileWidth[PAD_RIGHT] + one)
extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one));
// The offset for the right hand padding "bar" is:
// [right] lastDimSize - (tileWidth[PAD_RIGHT] + one)
extractOffset[PAD_RIGHT] =
createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one));
extractOffset[PAD_CENTER] = zero;
insertOffset[PAD_LEFT] = zero;
insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]);
insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT];
SmallVector<Value> resultShape{inputShape};
// Result's last dimension will have shape lastDimSize + left padding size + right padding size
resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT]));
Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType());
// Result's last dimension will have size:
// lastDimSize + left padding size + right padding size
resultShape[lastDim] =
createIAdd(resultShape[lastDim],
createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT]));
Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape,
inputType.getElementType());
// Helper to reflect/reverse the i-th dimension of an affine map without symbols. This only works if applied on a tensor
// for which the corresponding dimension has a statically known size
auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) {
// Helper to reflect/reverse the i-th dimension of an affine map without
// symbols. This only works if applied on a tensor for which the
// corresponding dimension has a statically known size
auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i,
int64_t size) {
AffineExpr d = map.getResult(i);
return map.replace(d, size - d - 1, numDims, 0); // left reflect for (3,1) on input shape (1,2,4). size = 3, lastDim=2, numDims=3
return map.replace(d, size - d - 1, numDims,
0); // left reflect for (3,1) on input shape (1,2,4).
// size = 3, lastDim=2, numDims=3
};
SmallVector<utils::IteratorType> iteratorTypes{numDims, utils::IteratorType::parallel};
SmallVector<utils::IteratorType> iteratorTypes{
numDims, utils::IteratorType::parallel};
auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context);
SmallVector<Value> allOneStrides(numDims, one);
@ -214,22 +237,26 @@ public:
Value tile = rewriter.create<tensor::ExtractSliceOp>(
loc, input, extractOffsets, extractShape, allOneStrides);
auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context);
// Setup the affine map function to resverse the tile along the horizontal for left and right slices
// Setup the affine map function to resverse the tile along the horizontal
// for left and right slices
if (padPosition < PAD_CENTER) {
inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]);
// Take reflected slice as per inputMap
tile = rewriter.create<linalg::GenericOp>(loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
tile = rewriter
.create<linalg::GenericOp>(
loc, llvm::cast<RankedTensorType>(tile.getType()), tile,
tile, ArrayRef({inputMap, idMap}), iteratorTypes,
[](OpBuilder &b, Location nestedLoc, ValueRange args) {
b.create<linalg::YieldOp>(nestedLoc, args[0]);
}).getResult(0);
})
.getResult(0);
}
// Insert the tile in the resultTensor
SmallVector<Value> insertOffsets(numDims, zero);
insertOffsets[lastDim] = insertOffset[padPosition];
resultTensor = rewriter.create<tensor::InsertSliceOp>(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
resultTensor = rewriter.create<tensor::InsertSliceOp>(
loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides);
};
if (padInts[PAD_LEFT] > 0)
@ -242,7 +269,7 @@ public:
return success();
}
};
}
} // namespace
namespace {

View File

@ -79,7 +79,8 @@ public:
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
int64_t inputRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
int64_t inputRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
@ -248,9 +249,9 @@ public:
}
if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) {
return rewriter.notifyMatchFailure(
op,
"Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag.");
return rewriter.notifyMatchFailure(op,
"Unimplemented: Mean and Max mode are "
"not supported yet for EmbeddingBag.");
}
bool isSparse;
@ -351,10 +352,10 @@ public:
Value indexI = b.create<linalg::IndexOp>(loc, /*value=*/0);
Value indexIToInt = castIndexToInt64(b, loc, indexI);
Value one = getConstant(
b, loc, 1,
mlir::IntegerType::get(getContext(), 64,
IntegerType::Signless));
Value one =
getConstant(b, loc, 1,
mlir::IntegerType::get(
getContext(), 64, IntegerType::Signless));
Value offsetIndexPlusOneInt =
b.create<arith::AddIOp>(loc, indexIToInt, one);
@ -393,14 +394,13 @@ public:
castIntToIndex(b, loc, indexInIndices));
indexIntoWeight.push_back(
b.create<linalg::IndexOp>(loc, /*value=*/2));
Value weightElem = b.create<tensor::ExtractOp>(
loc, weight, indexIntoWeight);
Value weightElem =
b.create<tensor::ExtractOp>(loc, weight, indexIntoWeight);
Value addResult = b.create<arith::AddFOp>(loc, weightElem,
initTensorElem);
Value select =
b.create<arith::SelectOp>(loc, indicesIndexWithinBounds,
addResult, initTensorElem);
Value addResult =
b.create<arith::AddFOp>(loc, weightElem, initTensorElem);
Value select = b.create<arith::SelectOp>(
loc, indicesIndexWithinBounds, addResult, initTensorElem);
b.create<linalg::YieldOp>(loc, select);
})
.getResult(0);
@ -552,7 +552,8 @@ static Value makeIndexValuePositive(OpBuilder &b, Location loc, Value index,
// e.g. x: [2, 3]
// x[[4], [6, 1]] -> x[6, 4]
namespace {
class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
class ConvertAtenIndexTensorHackedTwinOp
: public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult

View File

@ -165,7 +165,8 @@ public:
Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Value self = adaptor.getSelf();
auto selfRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
auto selfRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
Type elementType =
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
Value c1 =
@ -535,7 +536,8 @@ public:
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
Type newResultType = getTypeConverter()->convertType(op.getType());
Type resultElementType = newResultType.cast<RankedTensorType>().getElementType();
Type resultElementType =
newResultType.cast<RankedTensorType>().getElementType();
Type lhsElementType = lhsType.cast<RankedTensorType>().getElementType();
Type rhsElementType = rhsType.cast<RankedTensorType>().getElementType();
@ -547,13 +549,15 @@ public:
// Convert the inputs element type equivalent to the result' element type.
if (lhsElementType != rhsElementType) {
if (lhsElementType != resultElementType) {
// True if the lhs element type is not equal to the result' element type.
lhs = torch_to_linalg::convertTensorToElementType(
rewriter, loc, lhs, resultElementType);
// True if the lhs element type is not equal to the result' element
// type.
lhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, lhs,
resultElementType);
} else {
// True if the rhs element type is not equal to the result' element type.
rhs = torch_to_linalg::convertTensorToElementType(
rewriter, loc, rhs, resultElementType);
// True if the rhs element type is not equal to the result' element
// type.
rhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, rhs,
resultElementType);
}
}
@ -571,7 +575,8 @@ public:
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);
Value initTensor0 = createZeroInitTensor(
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType);
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2},
resultElementType);
Value bmm =
rewriter
@ -634,7 +639,8 @@ public:
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
SmallVector<int64_t> dilationInts;
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts)))
if (!matchPattern(op.getDilation(),
m_TorchListOfConstantInts(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
@ -838,8 +844,10 @@ public:
Value conv;
// the code so far is able to respect all numSpacialDims
// the code below this point is numSpacialDims specific and groupSize specific
// TODO: factor out the above code into a helper function, and then separate convolution into:
// the code below this point is numSpacialDims specific and groupSize
// specific
// TODO: factor out the above code into a helper function, and then separate
// convolution into:
// - grouped 1d-3d
// - ungrouped 1d-3d
if (groupSize == 1) {
@ -854,19 +862,19 @@ public:
.getResult(0);
break;
case 2:
conv =
rewriter
conv = rewriter
.create<linalg::Conv2DNchwFchwOp>(
loc, outputTensor.getType(), ValueRange{paddedInput, weight},
outputTensor, stridesAttr, dilationAttr)
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
case 3:
conv =
rewriter
conv = rewriter
.create<linalg::Conv3DNcdhwFcdhwOp>(
loc, outputTensor.getType(), ValueRange{paddedInput, weight},
outputTensor, stridesAttr, dilationAttr)
loc, outputTensor.getType(),
ValueRange{paddedInput, weight}, outputTensor,
stridesAttr, dilationAttr)
.getResult(0);
break;
default:

View File

@ -194,7 +194,6 @@ public:
};
} // namespace
void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {

View File

@ -144,8 +144,7 @@ public:
}
Value filledTensorVal =
rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal)
.result();
rewriter.create<linalg::FillOp>(loc, fillValue, initTensorVal).result();
// Create the affine expressions that will be used to
// iterate over the input and output tensors.
@ -220,8 +219,8 @@ public:
});
// This cast is required to fix the shape in the case of keepDim=True
Value valuesCast = rewriter.create<tensor::CastOp>(
loc, valResultType, linalgOp.getResult(0));
Value valuesCast = rewriter.create<tensor::CastOp>(loc, valResultType,
linalgOp.getResult(0));
Value idxCast = rewriter.create<tensor::CastOp>(loc, idxResultType,
linalgOp.getResult(1));
rewriter.replaceOp(op, {valuesCast, idxCast});
@ -345,7 +344,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc,
Value self = convertScalarToDtype(b, loc, elem, resultElementType);
auto abs = b.create<math::AbsFOp>(loc, self);
AtenLinalgVectorNormOp::Adaptor adaptor(operands);
Value ord = convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType);
Value ord =
convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType);
auto pow = b.create<math::PowFOp>(loc, abs, ord);
return b.create<arith::AddFOp>(loc, pow, result);
} else if (isa<AtenFrobeniusNormDimOp>(op)) {
@ -427,8 +427,8 @@ private:
opInfo.tensorOperand = operands[0];
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
// `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the dimensions of the
// input tensor.
// `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the
// dimensions of the input tensor.
for (int64_t i = 0; i < inputType.getRank(); i++)
opInfo.dimSet.insert(i);

View File

@ -120,13 +120,18 @@ namespace {
Value vDimSize = inputShape[vDim];
enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 };
enum tileVLoc { TOP = 0, VCENTER = 2, BOTTOM = 1, };
enum tileVLoc {
TOP = 0,
VCENTER = 2,
BOTTOM = 1,
};
// vTile denotes the vertical size of the tile
// hTile denotes the horizontal size of the tile
// The padding results are composed of following tiles:
// vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT]
// vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], vTile[VCENTER]hTile[RIGHT]
// vTile[BOTTOM]hTile[LEFT], vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT]
// vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER],
// vTile[VCENTER]hTile[RIGHT] vTile[BOTTOM]hTile[LEFT],
// vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT]
// vTile[VCENTER]hTile[HCENTER] is the original input tensor
Type indexType = rewriter.getIndexType();
Value vTile[3];
@ -215,16 +220,19 @@ namespace {
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
lowPadding[2] = padInts[2];
vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue);
vLeftSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue);
}
if (hasBottomPadding) {
Value bottomLeftValue = rewriter.create<tensor::ExtractOp> (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero});
Value bottomLeftValue = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero});
// pad vLeftSlice at the bottom
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
highPadding[2] = padInts[3];
vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue);
vLeftSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue);
}
for (auto i = 0; i < padInts[0]; ++i) {
tensorsLeft.push_back(vLeftSlice);
@ -256,27 +264,34 @@ namespace {
loc, input, extractOffsetsRight, extractShapeLR, allOneStrides);
Value vRightSlice = vCenterRightSlice;
if (hasTopPadding) {
Value topRightValue = rewriter.create<tensor::ExtractOp> (loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne});
Value topRightValue = rewriter.create<tensor::ExtractOp>(
loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne});
// pad vCenterRightSlice on the top
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
lowPadding[2] = padInts[2];
vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue);
vRightSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue);
}
if (hasBottomPadding) {
Value bottomRightValue = rewriter.create<tensor::ExtractOp> (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne});
Value bottomRightValue = rewriter.create<tensor::ExtractOp>(
loc, input,
ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne});
// Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom.
SmallVector<int64_t> lowPadding(4, 0);
SmallVector<int64_t> highPadding(4, 0);
highPadding[2] = padInts[3];
vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, bottomRightValue);
vRightSlice = torch_to_linalg::getPaddedTensor(
op, rewriter, vRightSlice, lowPadding, highPadding,
bottomRightValue);
}
for (auto i = 0; i < padInts[1]; ++i) {
tensorsRight.push_back(vRightSlice);
}
Value rightPadTile = rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRight);
Value rightPadTile =
rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRight);
tensorsRes.push_back(rightPadTile);
}
Value resTensor = rewriter.create<tensor::ConcatOp>(loc, 3, tensorsRes);
@ -285,7 +300,7 @@ namespace {
return success();
}
};
}
} // namespace
namespace {
// Converts constant tensor allocation like ops.
@ -348,8 +363,8 @@ public:
// Create an uninitialized tensor of `resultSize` shape and fill it with
// value `fillVal`.
Value constVal = getConstant(rewriter, loc, fillVal, resultElementType);
Value outputTensor =
createInitTensor(rewriter, loc, resultSizeIndex, resultElementType, constVal);
Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex,
resultElementType, constVal);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, outputTensor);
return success();
}
@ -384,7 +399,8 @@ public:
// Only `none`, `contiguous` and `preserve` memory_format is supported.
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
int64_t memoryFormat;
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
if (!matchPattern(op.getMemoryFormat(),
m_TorchConstantInt(&memoryFormat)))
return rewriter.notifyMatchFailure(
op, "unimplemented: the memory format should be specified in "
"an integer constant");
@ -495,7 +511,8 @@ public:
typeConverter->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type dtype = resultType.getElementType();
Value start = convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype);
Value start =
convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype);
Value end = convertScalarToDtype(rewriter, loc, adaptor.getEnd(), dtype);
Value step = convertScalarToDtype(rewriter, loc, adaptor.getStep(), dtype);

View File

@ -429,7 +429,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
if (isa<AtenIsinfOp>(op)) {
Value abs = b.create<math::AbsFOp>(loc, payloadArgs[0]);
Value infinity = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
loc,
b.getFloatAttr(abs.getType(), std::numeric_limits<double>::infinity()));
return createEqual(b, loc, abs.getType(), abs, infinity);
}
if (isa<AtenSigmoidOp>(op)) {

View File

@ -7,13 +7,13 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "../PassDetail.h"
#include "PopulatePatterns.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"

View File

@ -923,8 +923,7 @@ LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
op.getA().getType().template cast<BaseTensorType>().getDtype();
Type resultType =
this->getTypeConverter()->convertType(op->getResult(0).getType());
auto result =
rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
auto result = rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
rewriter.replaceOp(
op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype));
@ -1797,8 +1796,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
#define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, \
context)
patterns.add<ConvertAtenTensorToScalarLikeOp<AtenOp>>(typeConverter, context)
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp);
INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp);

View File

@ -35,7 +35,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy,
PatternRewriter &rewriter) {
auto constType = RankedTensorType::get({}, elementTy);
// Avg pooling
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp, AtenCumsumOp>(op)) {
if (isa<AtenAvgPool1dOp, AtenAdaptiveAvgPool2dOp, AtenAvgPool2dOp,
AtenCumsumOp>(op)) {
if (elementTy.isa<mlir::FloatType>()) {
auto constAttr = DenseElementsAttr::get(
constType, {APFloat::getZero(
@ -373,7 +374,6 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
return success();
}
namespace {
template <typename AtenOpT, int Dim>
class ConvertAtenAvgPoolOp : public ConvertAtenOp<AtenOpT> {
@ -392,7 +392,6 @@ public:
.template cast<RankedTensorType>();
auto outShape = outTy.getShape();
if (inputRank <= Dim) {
return op.emitError(
"avg_pooling1d/2d only supports inputs with rank higher than 1/2");
@ -407,15 +406,16 @@ public:
op, "non-const int kernel size unsupported!");
}
if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) {
return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!");
return rewriter.notifyMatchFailure(op,
"non-const int stride unsupported!");
}
if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) {
return rewriter.notifyMatchFailure(op,
"non-const int padding unsupported!");
}
if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) {
return rewriter.notifyMatchFailure(op,
"non-const bool ceil_mode unsupported!");
return rewriter.notifyMatchFailure(
op, "non-const bool ceil_mode unsupported!");
}
if (!(matchPattern(op.getCountIncludePad(),
m_TorchConstantBool(&countIncludePad)))) {
@ -450,10 +450,12 @@ public:
stablehloPadding[stablehloPadding.size() - 1] = padding[1];
}
Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
Value initVal =
createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter);
DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(stablehloKernelSize.size())},
RankedTensorType::get(
{static_cast<int64_t>(stablehloKernelSize.size())},
rewriter.getI64Type()),
stablehloKernelSize);
DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get(
@ -518,8 +520,8 @@ public:
windowSizeConst =
hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy);
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
auto inputShapeVec =
*hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits);
auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input,
options.dimSizeIndexBits);
auto inputShapeTensor = rewriter.create<mlir::tensor::FromElementsOp>(
op->getLoc(), inputShapeVec);
@ -555,12 +557,9 @@ public:
rewriter.replaceOpWithNewOp<stablehlo::DivOp>(
op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0));
return success();
}
};
}
} // namespace
// AtenCumsumOp
template <>
@ -662,8 +661,8 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality(
patterns.add<ConvertAtenOp<AtenCumsumOp>>(typeConverter, context, options);
#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>( \
typeConverter, context, options)
patterns.add<ConvertAtenAvgPoolOp<AtenOp, Dim>>(typeConverter, context, \
options)
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1);
INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2);
#undef INSERT_ATEN_AVGPOOL_PATTERN

View File

@ -16,13 +16,13 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
using namespace mlir;
using namespace mlir::torch;

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
@ -22,7 +23,6 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h"
#include <numeric>
using namespace mlir;
@ -403,7 +403,8 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
return op->emitError("dim must be a Scalar constant");
int64_t inputRank = adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
int64_t inputRank =
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
dim = toPositiveDim(dim, inputRank + 1);
if (!isValidDim(dim, inputRank + 1))
return rewriter.notifyMatchFailure(op, "dim is statically invalid");

View File

@ -210,9 +210,9 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
// Lowers Gather operators to a sequence of TOSA ops.
// taken from
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter,
Operation *op, Type outType,
Value paramsValue, Value indicesValue) {
std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
Type outType, Value paramsValue,
Value indicesValue) {
auto resultType = outType.dyn_cast<ShapedType>();
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
@ -683,7 +683,6 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
.getResult();
}
// Common function for lowering reduce operations to TOSA ops.
template <typename T>
std::optional<Value> convertReduceOpCommon(
@ -721,9 +720,8 @@ std::optional<Value> convertReduceOpCommon(
auto axis_attr = rewriter.getI32IntegerAttr(axis_val);
shape_vec[axis_val] = 1;
RankedTensorType reduce_type = RankedTensorType::get(
shape_vec,
reduce_element_type);
RankedTensorType reduce_type =
RankedTensorType::get(shape_vec, reduce_element_type);
auto reduce_op = CreateOpAndInfer<T>(rewriter, op->getLoc(), reduce_type,
val, axis_attr);

View File

@ -176,7 +176,8 @@ std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
// Default template creates a constant tensor in T.
template <typename T>
std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
ArrayRef<T> vec, ArrayRef<int64_t> shape, std::optional<Type> dtype) {
ArrayRef<T> vec, ArrayRef<int64_t> shape,
std::optional<Type> dtype) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
@ -209,7 +210,8 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
template <>
std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
Operation *op, ArrayRef<APInt> vec,
ArrayRef<int64_t> shape, std::optional<Type> dtype) {
ArrayRef<int64_t> shape,
std::optional<Type> dtype) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
@ -238,7 +240,8 @@ std::optional<Value> getConstTensor<APInt>(PatternRewriter &rewriter,
template <>
std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
Operation *op, ArrayRef<float> vec,
ArrayRef<int64_t> shape, std::optional<Type> dtype) {
ArrayRef<int64_t> shape,
std::optional<Type> dtype) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
@ -347,23 +350,17 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
}
// Template instantiation
template std::optional<Value> getConstTensor<bool>(PatternRewriter &,
Operation *,
ArrayRef<bool> vec,
ArrayRef<int64_t> shape,
std::optional<Type> dtype);
template std::optional<Value>
getConstTensor<bool>(PatternRewriter &, Operation *, ArrayRef<bool> vec,
ArrayRef<int64_t> shape, std::optional<Type> dtype);
template std::optional<Value> getConstTensor<int32_t>(PatternRewriter &,
Operation *,
ArrayRef<int32_t> vec,
ArrayRef<int64_t> shape,
std::optional<Type> dtype);
template std::optional<Value>
getConstTensor<int32_t>(PatternRewriter &, Operation *, ArrayRef<int32_t> vec,
ArrayRef<int64_t> shape, std::optional<Type> dtype);
template std::optional<Value> getConstTensor<int64_t>(PatternRewriter &,
Operation *,
ArrayRef<int64_t> vec,
ArrayRef<int64_t> shape,
std::optional<Type> dtype);
template std::optional<Value>
getConstTensor<int64_t>(PatternRewriter &, Operation *, ArrayRef<int64_t> vec,
ArrayRef<int64_t> shape, std::optional<Type> dtype);
LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
TypeAttr &accType) {

View File

@ -87,7 +87,8 @@ static TMTensorOp createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter,
ValueRange outputs) {
SmallVector<Value, 8> newOperands = inputs;
newOperands.append(outputs.begin(), outputs.end());
return cast<TMTensorOp>(tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands));
return cast<TMTensorOp>(
tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands));
}
/// Generic conversion pattern that matches any TMTensorOp. This avoids template

View File

@ -203,8 +203,8 @@ static Value getScalarFloatValue(Value input, Location loc,
//===----------------------------------------------------------------------===//
LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto func =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFunctionAttr());
auto func = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
*this, getFunctionAttr());
if (!func)
return emitError() << "'@" << getFunction()
<< "' does not reference a valid function";
@ -453,11 +453,13 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// If the condition is constant, delete the dead branch and inline the live
// branch.
patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
auto constantBool = op.getCondition().getDefiningOp<Torch::ConstantBoolOp>();
auto constantBool =
op.getCondition().getDefiningOp<Torch::ConstantBoolOp>();
if (!constantBool)
return rewriter.notifyMatchFailure(op, "non-constant condition");
replaceOpWithRegion(
rewriter, op, constantBool.getValue() ? op.getThenRegion() : op.getElseRegion());
replaceOpWithRegion(rewriter, op,
constantBool.getValue() ? op.getThenRegion()
: op.getElseRegion());
return success();
});
// If the thenRegion and elseRegion yield the same Value's, then use those
@ -515,14 +517,16 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
continue;
newResultTypes.push_back(op->getResult(i).getType());
}
auto newIf =
rewriter.create<PrimIfOp>(op->getLoc(), newResultTypes, op.getCondition());
auto newIf = rewriter.create<PrimIfOp>(op->getLoc(), newResultTypes,
op.getCondition());
rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(),
newIf.getThenRegion().end());
rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(),
newIf.getElseRegion().end());
newIf.getThenRegion().front().getTerminator()->eraseOperands(resultsToErase);
newIf.getElseRegion().front().getTerminator()->eraseOperands(resultsToErase);
newIf.getThenRegion().front().getTerminator()->eraseOperands(
resultsToErase);
newIf.getElseRegion().front().getTerminator()->eraseOperands(
resultsToErase);
SmallVector<Value> replacementValues;
for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) {
if (resultsToErase[i])
@ -900,8 +904,8 @@ void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
auto getRhsDtype = rewriter.create<PrimDtypeOp>(op.getLoc(), rhs);
rewriter.replaceOpWithNewOp<AtenToDeviceOp>(
op, op.getType(), lhs, getRhsDevice.getResult(),
getRhsDtype.getResult(), op.getNonBlocking(),
op.getCopy(), op.getMemoryFormat());
getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(),
op.getMemoryFormat());
return success();
});
}
@ -2045,7 +2049,8 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
// compiler treat the size as having value semantics?
// There's a small number of such ops, and they are marked as `inplace_view`
// in PyTorch's `native_functions.yaml` file.
rewriter.replaceOpWithNewOp<AtenSizeIntOp>(op, sizeOp.getSelf(), op.getIdx());
rewriter.replaceOpWithNewOp<AtenSizeIntOp>(op, sizeOp.getSelf(),
op.getIdx());
return success();
});
}
@ -2073,11 +2078,13 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
auto lhsListConstruct = op.getA().getDefiningOp<Torch::PrimListConstructOp>();
auto lhsListConstruct =
op.getA().getDefiningOp<Torch::PrimListConstructOp>();
if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
return failure();
auto rhsListConstruct = op.getB().getDefiningOp<Torch::PrimListConstructOp>();
auto rhsListConstruct =
op.getB().getDefiningOp<Torch::PrimListConstructOp>();
if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
return failure();
@ -2195,7 +2202,8 @@ LogicalResult PrimTupleConstructOp::verify() {
void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
auto tupleConstruct = op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
auto tupleConstruct =
op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
if (!tupleConstruct)
return failure();
@ -2245,7 +2253,8 @@ void PrimUninitializedOp::getCanonicalizationPatterns(
void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
auto tupleConstruct = op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
auto tupleConstruct =
op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
if (!tupleConstruct)
return failure();
@ -2400,9 +2409,7 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
// AtenAliasOp
//===----------------------------------------------------------------------===//
OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) {
return getOperand();
}
OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); }
//===----------------------------------------------------------------------===//
// AtenFloordivIntOp
@ -2484,10 +2491,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
int64_t start, end, step;
if (matchPattern(getStart(), m_TorchConstantInt(&start)) &&
matchPattern(getEnd(), m_TorchConstantInt(&end)) &&
matchPattern(getStep(), m_TorchConstantInt(&step))
&& step == 1
&& start == 0
&& end == std::numeric_limits<int64_t>::max())
matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 &&
start == 0 && end == std::numeric_limits<int64_t>::max())
return getOperand(0);
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
@ -2955,7 +2960,6 @@ LogicalResult AtenPermuteOp::verify() {
<< " elements, the output has rank " << outRank << '.';
}
// Initialization of the reverse permutation. -1 denotes an unknown
// permutation index.
SmallVector<int64_t> reversePermutation(outRank, -1);

View File

@ -556,9 +556,9 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
// TODO: These are not DRY in that the two type predicates AnyTorchDictKeyType
// and AnyTorchType generate the exact same code (in TorchOps.cpp.inc).
// Unfortunately the generated implementations aren't visible/exposed ("static" linkage)
// and the predicates themselves can't be added/used in the specification of the parameters
// of the Torch_DictType.
// Unfortunately the generated implementations aren't visible/exposed ("static"
// linkage) and the predicates themselves can't be added/used in the
// specification of the parameters of the Torch_DictType.
static bool isAnyTorchDictKeyType(Type type) {
return type.isa<Torch::AnyType>() || type.isa<Torch::IntType>() ||
type.isa<Torch::BoolType>() || type.isa<Torch::FloatType>() ||

View File

@ -457,7 +457,6 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
return success();
}
static Value performLastReduceAndPermute(PatternRewriter &rewriter,
Location loc, Type outType,
Value input,
@ -1269,7 +1268,8 @@ public:
};
} // namespace
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into `AtenMinDimOp`
// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into
// `AtenMinDimOp`
namespace {
template <typename OpTy, typename DecompOpTy>
class DecomposeAtenArgMinMaxOp : public OpRewritePattern<OpTy> {
@ -1300,9 +1300,9 @@ public:
.cast<BaseTensorType>();
// If the dim type is `NoneType` i.e. reduce along all the dimensions.
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so first the input
// tensor is flattened to 1d tensor and then the reduction happens on the
// 0th dimension.
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
// first the input tensor is flattened to 1d tensor and then the reduction
// happens on the 0th dimension.
if (dim.getType().isa<Torch::NoneType>()) {
BaseTensorType flattenType =
inputType
@ -1318,8 +1318,8 @@ public:
Value resultArg =
rewriter
.create<DecompOpTy>(loc, valueTensorType, indicesTensorType,
input, dim, keepDim)
.create<DecompOpTy>(loc, valueTensorType, indicesTensorType, input,
dim, keepDim)
.getIndices();
rewriter.replaceOp(op, resultArg);
@ -1961,8 +1961,10 @@ public:
double alpha = 1.6732632423543772848170429916717;
// Create constants for λ and α
Value scaleVal = rewriter.create<Torch::ConstantFloatOp>(loc, rewriter.getF64FloatAttr(scale));
Value alphaVal = rewriter.create<Torch::ConstantFloatOp>(loc, rewriter.getF64FloatAttr(alpha));
Value scaleVal = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(scale));
Value alphaVal = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(alpha));
// Create zero tensor for comparison
Value constantZero =
@ -1972,17 +1974,21 @@ public:
// Calculate positive and negative parts
Value constantOne =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value positiveOutput = rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
Value positiveOutput =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
Value minZeroX =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
Value expInput = rewriter.create<AtenExpOp>(loc, resType, minZeroX);
Value expInputMinusOne = rewriter.create<AtenSubScalarOp>(loc, resType, expInput, constantOne, constantOne);
Value negativeOutput = rewriter.create<AtenMulScalarOp>(loc, resType, expInputMinusOne, alphaVal);
Value expInputMinusOne = rewriter.create<AtenSubScalarOp>(
loc, resType, expInput, constantOne, constantOne);
Value negativeOutput = rewriter.create<AtenMulScalarOp>(
loc, resType, expInputMinusOne, alphaVal);
// Multiply the result by λ
Value seluOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOne);
seluOutput = rewriter.create<AtenMulScalarOp>(loc, resType, seluOutput, scaleVal);
seluOutput =
rewriter.create<AtenMulScalarOp>(loc, resType, seluOutput, scaleVal);
// Replace the original operation
rewriter.replaceOp(op, seluOutput);
@ -2594,7 +2600,8 @@ namespace {
static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter,
Location loc, Value input,
int64_t dimA, int64_t dimB,
int64_t dimA,
int64_t dimB,
Value &transposed) {
Type transposedType;
if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
@ -2615,17 +2622,18 @@ namespace {
LogicalResult matchAndRewrite(AtenConvTbcOp op,
PatternRewriter &rewriter) const override {
Value emptyList = rewriter.create<PrimListConstructOp>(
op.getLoc(),
Torch::ListType::get(Torch::IntType::get(op.getContext())),
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
SmallVector<Value>());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
Value oneList = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
SmallVector<Value>{rewriter.create<Torch::ConstantIntOp>(op.getLoc(), rewriter.getI64IntegerAttr(1))});
SmallVector<Value>{rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(1))});
Value padding = rewriter.create<PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())),
SmallVector<Value>{op.getPad()});
Value groups = rewriter.create<Torch::ConstantIntOp>(op.getLoc(), rewriter.getI64IntegerAttr(1));
Value groups = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(1));
// convtbc has WNC layout for input and output
// and WCF layout for weight
@ -2634,37 +2642,45 @@ namespace {
Value selfWnc = op.getSelf();
Value selfNwc;
Value selfNcw;
if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc, 0, 1, selfNwc)))
return rewriter.notifyMatchFailure(op, "failed to transpose input to Nwc");
if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc, 1, 2, selfNcw)))
return rewriter.notifyMatchFailure(op, "failed to transpose input to Ncw");
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc,
0, 1, selfNwc)))
return rewriter.notifyMatchFailure(op,
"failed to transpose input to Nwc");
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc,
1, 2, selfNcw)))
return rewriter.notifyMatchFailure(op,
"failed to transpose input to Ncw");
Value weightWcf = op.getWeight();
Value weightFcw;
if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), weightWcf, 0, 2, weightFcw)))
return rewriter.notifyMatchFailure(op, "failed to transpose weight to Fcw");
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(),
weightWcf, 0, 2, weightFcw)))
return rewriter.notifyMatchFailure(op,
"failed to transpose weight to Fcw");
Value outputNcw = rewriter.create<AtenConvolutionOp>(
op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(), /*stride*/oneList,
op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(),
/*stride*/ oneList,
/*padding*/ padding, /*dilation*/ oneList,
/*transpose*/ cstFalse, /*output_padding*/ emptyList,
groups);
/*transpose*/ cstFalse, /*output_padding*/ emptyList, groups);
// convert output from Ncw to Wnc
Value outputNwc;
Value outputWnc;
if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNcw, 1, 2, outputNwc)))
return rewriter.notifyMatchFailure(op, "failed to transpose output to Nwc");
if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNwc, 0, 1, outputWnc)))
return rewriter.notifyMatchFailure(op, "failed to transpose output to Wnc");
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(),
outputNcw, 1, 2, outputNwc)))
return rewriter.notifyMatchFailure(op,
"failed to transpose output to Nwc");
if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(),
outputNwc, 0, 1, outputWnc)))
return rewriter.notifyMatchFailure(op,
"failed to transpose output to Wnc");
rewriter.replaceOp(op, outputWnc);
return success();
}
};
}
} // namespace
// Decompose aten.conv1d to aten.convolution
namespace {
@ -3815,8 +3831,8 @@ public:
/*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
Value stdRandN =
rewriter.create<AtenMulScalarOp>(loc, resultType, randN, std);
rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN,
mean, /*alpha=*/one);
rewriter.replaceOpWithNewOp<AtenAddScalarOp>(op, resultType, stdRandN, mean,
/*alpha=*/one);
return success();
}
};
@ -6654,8 +6670,10 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTranspose2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArangeStartOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgmaxOp, AtenMaxDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenArgMinMaxOp<AtenArgminOp, AtenMinDimOp>>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSquareOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdOp>(patterns);
@ -6768,8 +6786,6 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenConv2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConv3dOp>(patterns);
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.maxIterations = GreedyRewriteConfig::kNoLimit;

View File

@ -170,8 +170,8 @@ private:
auto attr = std::get<1>(t);
nameStack.push_back(attr.getName().str());
if (attr.getType().isa<NnModuleType>()) {
if (failed(
recursivelyTraverse(slot.getValue().getDefiningOp<NnModuleOp>())))
if (failed(recursivelyTraverse(
slot.getValue().getDefiningOp<NnModuleOp>())))
return failure();
} else if (usedSlots.find(slot) != usedSlots.end()) {
// Only create the GlobalSlotOp if the slot is used at all.
@ -190,8 +190,8 @@ private:
}
for (auto method : classType.getOps<MethodOp>()) {
nameStack.push_back(method.getName().str());
funcLinkageInfo[{nnModule,
symbolTable.lookup<func::FuncOp>(method.getFunction())}] =
funcLinkageInfo[{
nnModule, symbolTable.lookup<func::FuncOp>(method.getFunction())}] =
LinkageInfo{llvm::join(nameStack, "."), method.getIsPrivate()};
nameStack.pop_back();
}
@ -501,21 +501,24 @@ static LogicalResult rewriteMonomorphizedFuncClone(
SmallVector<Operation *> toErase;
auto handlePrimSetAttr = [&](PrimSetAttrOp op) {
auto instance = mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
auto instance =
mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
SlotOp affectedSlot;
for (auto slot : instance.getOps<SlotOp>()) {
if (slot.getName() == op.getName())
affectedSlot = slot;
}
OpBuilder(op).create<GlobalSlotSetOp>(
op.getLoc(), objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(),
op.getLoc(),
objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(),
op.getValue());
toErase.push_back(op);
return WalkResult::advance();
};
auto handlePrimGetAttr = [&](PrimGetAttrOp op) {
if (!op.getType().isa<NnModuleType>()) {
auto instance = mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
auto instance =
mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
SlotOp affectedSlot;
for (auto slot : instance.getOps<SlotOp>()) {
if (slot.getName() == op.getName())

View File

@ -163,7 +163,8 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) {
}
if (auto globalSlotSet = dyn_cast<Torch::GlobalSlotSetOp>(op)) {
auto *state = getOrCreate<InlineGlobalSlotsAnalysisState>(
getProgramPoint<FlatSymbolRefProgramPoint>(globalSlotSet.getSlotAttr()));
getProgramPoint<FlatSymbolRefProgramPoint>(
globalSlotSet.getSlotAttr()));
propagateIfChanged(state, state->setSafe(false));
}
// Save the InitializeGlobalSlotsOp for later referencee
@ -211,8 +212,8 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) {
auto it =
llvm::find(initializeGlobalSlotsOp.getSlotSymNames(),
static_cast<Attribute>(flatSymbolRefPoint->getValue()));
Value value = initializeGlobalSlotsOp->getOperand(
std::distance(initializeGlobalSlotsOp.getSlotSymNames().begin(), it));
Value value = initializeGlobalSlotsOp->getOperand(std::distance(
initializeGlobalSlotsOp.getSlotSymNames().begin(), it));
auto *flatSymbolRefState =
getOrCreateFor<InlineGlobalSlotsAnalysisState>(value,
flatSymbolRefPoint);
@ -331,7 +332,8 @@ class InlineGlobalSlotsPass
DenseSet</*FlatSymbolRefAttr*/ Attribute> safeToInline;
for (int i = 0, e = initialize->getNumOperands(); i != e; i++) {
auto slotSymName = initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
auto slotSymName =
initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
Value operand = initialize.getOperand(i);
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>());
@ -405,7 +407,8 @@ class InlineGlobalSlotsPass
SmallVector<Attribute> newSlotSymNames;
SmallVector<Value> newInitialValues;
for (int i = 0, e = initialize.getNumOperands(); i != e; i++) {
auto slotSymName = initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
auto slotSymName =
initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
if (!safeToInline.count(slotSymName)) {
newSlotSymNames.push_back(slotSymName);
newInitialValues.push_back(initialize.getOperand(i));

View File

@ -202,15 +202,16 @@ static bool satisfiesBackendContract(ModuleOp module,
// Check for unimplemented operators first to give more direct diagnostics.
walkResult0 = module.walk([&](Torch::OperatorOp op) {
if (llvm::all_of(op.getResults(), [&op](auto res) {
return succeeded(
checkType(op.getOperation(), res.getType(), /*actuallyEmitDiagnostics=*/false));
return succeeded(checkType(op.getOperation(), res.getType(),
/*actuallyEmitDiagnostics=*/false));
})) {
return WalkResult::advance();
}
if (actuallyEmitDiagnostics) {
op->emitError("unsupported by backend contract: Unimplemented operator '"
+ op.getName() + "'");
op->emitError(
"unsupported by backend contract: Unimplemented operator '" +
op.getName() + "'");
}
return WalkResult::interrupt();
});
@ -309,12 +310,14 @@ public:
<< " iterations of the simplification pipeline\n";
});
}
private:
llvm::StringSet<> backendLegalOpsSet;
};
class VerifyBackendContractNoDecompositionsPass
: public VerifyBackendContractNoDecompositionsBase<VerifyBackendContractNoDecompositionsPass> {
: public VerifyBackendContractNoDecompositionsBase<
VerifyBackendContractNoDecompositionsPass> {
public:
VerifyBackendContractNoDecompositionsPass() = default;

View File

@ -158,9 +158,11 @@ void Torch::importLibraryFunctions(ModuleOp module, ModuleOp library,
}
}
FailureOr<Value> Torch::adjustFunctionArg(
OpBuilder &b, Location loc, Value operand, Type desiredType,
function_ref<Value(OpBuilder &, Location, Value, Type)> baseTransformation) {
FailureOr<Value>
Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
Type desiredType,
function_ref<Value(OpBuilder &, Location, Value, Type)>
baseTransformation) {
operand = baseTransformation(b, loc, operand, desiredType);
// No need for adjustment if they already match.

View File

@ -90,7 +90,8 @@ public:
PatternRewriter &rewriter) const override {
SmallVector<std::optional<int64_t>> ranks;
SmallVector<int64_t> dtypes;
if (!matchPattern(op.getRanks(), m_TorchListOfOptionalConstantInts(ranks))) {
if (!matchPattern(op.getRanks(),
m_TorchListOfOptionalConstantInts(ranks))) {
return rewriter.notifyMatchFailure(
op, "Expected `ranks` to be a list of optional constant ints");
}

View File

@ -54,13 +54,13 @@ void TorchConversionDialect::initialize() {
addInterfaces<TorchConversionInlinerInterface>();
}
//===----------------------------------------------------------------------===//
// Constant materializer.
//===----------------------------------------------------------------------===//
Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Attribute value,
Type type,
Location loc) {
if (auto integerType = type.dyn_cast<Torch::IntType>())
return builder.create<Torch::ConstantIntOp>(loc, value.cast<IntegerAttr>());

View File

@ -7,8 +7,8 @@
//
//===----------------------------------------------------------------------===//
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
using namespace mlir;
using namespace mlir::torch;
@ -57,8 +57,8 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
typeConverter.addConversion([](Torch::BoolType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 1);
});
typeConverter.addTargetMaterialization([](OpBuilder &builder,
IntegerType type, ValueRange inputs,
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, IntegerType type, ValueRange inputs,
Location loc) -> std::optional<Value> {
// Other builtin integer types could be handled by other materializers.
if (!(type.getWidth() == 1 && type.isSignless()))
@ -83,8 +83,8 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
typeConverter.addConversion([](Torch::IntType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 64);
});
typeConverter.addTargetMaterialization([](OpBuilder &builder,
IntegerType type, ValueRange inputs,
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, IntegerType type, ValueRange inputs,
Location loc) -> std::optional<Value> {
// Other builtin integer types could be handled by other materializers.
if (!(type.getWidth() == 64 && type.isSignless()))
@ -112,8 +112,8 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
typeConverter.addConversion([](Torch::FloatType type) -> std::optional<Type> {
return Float64Type::get(type.getContext());
});
typeConverter.addTargetMaterialization([](OpBuilder &builder,
Float64Type type, ValueRange inputs,
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, Float64Type type, ValueRange inputs,
Location loc) -> std::optional<Value> {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<Torch::FloatType>());
@ -133,11 +133,12 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
TypeConverter &typeConverter) {
target.addLegalOp<TorchConversion::GeneratorToI64Op,
TorchConversion::I64ToGeneratorOp>();
typeConverter.addConversion([](Torch::GeneratorType type) -> std::optional<Type> {
typeConverter.addConversion(
[](Torch::GeneratorType type) -> std::optional<Type> {
return IntegerType::get(type.getContext(), 64);
});
typeConverter.addTargetMaterialization([](OpBuilder &builder,
IntegerType type, ValueRange inputs,
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, IntegerType type, ValueRange inputs,
Location loc) -> std::optional<Value> {
// Other builtin integer types could be handled by other materializers.
if (!(type.getWidth() == 64 && type.isSignless()))

View File

@ -18,8 +18,8 @@
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
using namespace mlir;
using namespace mlir::torch;
@ -65,7 +65,8 @@ public:
auto getConstantIntegerFromDefiningOp = [](Value operand,
int &extractedInt) {
auto castOp = dyn_cast<mlir::UnrealizedConversionCastOp>(operand.getDefiningOp());
auto castOp =
dyn_cast<mlir::UnrealizedConversionCastOp>(operand.getDefiningOp());
if (!castOp) {
return failure();
}
@ -83,7 +84,8 @@ public:
return failure();
}
int unpackedBitWidth;
if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, unpackedBitWidth))) {
if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth,
unpackedBitWidth))) {
return failure();
}
if (unpackedBitWidth !=
@ -103,32 +105,35 @@ public:
// expand lhs
std::vector<int64_t> lhsExpandedShape = {lhsShape[0], lhsShape[1],
lhsReductDimSize / gs, gs};
RankedTensorType lhsExpandedType = RankedTensorType::get(lhsExpandedShape, elementType);
RankedTensorType lhsExpandedType =
RankedTensorType::get(lhsExpandedShape, elementType);
SmallVector<ReassociationIndices, 4> lhsReassociation = {{0}, {1}, {2, 3}};
Value lhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
loc, lhsExpandedType, lhs, lhsReassociation);
// expand rhs
std::vector<int64_t> rhsExpandedShape = {rhsShape[0], rhsReductDimSize/gs, gs};
RankedTensorType rhsExpandedType = RankedTensorType::get(rhsExpandedShape, rhsElementType);
std::vector<int64_t> rhsExpandedShape = {rhsShape[0], rhsReductDimSize / gs,
gs};
RankedTensorType rhsExpandedType =
RankedTensorType::get(rhsExpandedShape, rhsElementType);
SmallVector<ReassociationIndices, 4> rhsReassociation = {{0}, {1, 2}};
Value rhsExpanded = rewriter.create<tensor::ExpandShapeOp>(
loc, rhsExpandedType, rhsQuant, rhsReassociation);
Value cst0 = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
Value emptyDequant = rewriter.create<tensor::EmptyOp>(
loc, rhsExpandedShape, elementType);
Value emptyDequant =
rewriter.create<tensor::EmptyOp>(loc, rhsExpandedShape, elementType);
SmallVector<Value> dynDims;
for (int i = 0; i < lhsType.getRank(); i++) {
if (lhsType.isDynamicDim(i)) {
dynDims.push_back(rewriter.create<tensor::DimOp>(loc, lhs, i));
}
}
Value empty = rewriter.create<tensor::EmptyOp>(
loc, resultShape, elementType, dynDims);
Value output = rewriter.create<linalg::FillOp>(
loc, cst0, empty).getResult(0);
Value empty = rewriter.create<tensor::EmptyOp>(loc, resultShape,
elementType, dynDims);
Value output =
rewriter.create<linalg::FillOp>(loc, cst0, empty).getResult(0);
AffineExpr d0, d1, d2, d3, d4;
bindDims(getContext(), d0, d1, d2, d3, d4);
@ -141,12 +146,12 @@ public:
SmallVector<AffineMap, 4> dqIndexingMaps = {map, map1, map1, map};
SmallVector<AffineMap, 4> matIndexingMaps = {map2, map3, map4};
SmallVector<utils::IteratorType> dequantIteratorTypes(3, utils::IteratorType::parallel);
SmallVector<utils::IteratorType> dequantIteratorTypes(
3, utils::IteratorType::parallel);
SmallVector<utils::IteratorType> matmulIteratorTypes = {
utils::IteratorType::parallel, utils::IteratorType::parallel,
utils::IteratorType::parallel, utils::IteratorType::reduction,
utils::IteratorType::reduction
};
utils::IteratorType::reduction};
Value rhsDequant =
rewriter
@ -157,9 +162,12 @@ public:
/*iteratorTypes=*/dequantIteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value w = args[0], scale = args[1], zeroPoint = args[2];
Value extw = b.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), w);
Value fp_extw = b.create<arith::UIToFPOp>(loc, rewriter.getF16Type(), extw);
Value shifted = b.create<arith::SubFOp>(loc, fp_extw, zeroPoint);
Value extw =
b.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), w);
Value fp_extw = b.create<arith::UIToFPOp>(
loc, rewriter.getF16Type(), extw);
Value shifted =
b.create<arith::SubFOp>(loc, fp_extw, zeroPoint);
Value dqw = b.create<arith::MulFOp>(loc, shifted, scale);
b.create<linalg::YieldOp>(loc, dqw);
})
@ -168,8 +176,8 @@ public:
Value matmulDequant =
rewriter
.create<linalg::GenericOp>(
loc, output.getType(),
ValueRange{lhsExpanded, rhsDequant}, output,
loc, output.getType(), ValueRange{lhsExpanded, rhsDequant},
output,
/*indexingMaps=*/matIndexingMaps,
/*iteratorTypes=*/matmulIteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
@ -188,7 +196,8 @@ public:
namespace {
class ConvertCustomQuantOpPass
: public TorchConversion::ConvertCustomQuantOpBase<ConvertCustomQuantOpPass> {
: public TorchConversion::ConvertCustomQuantOpBase<
ConvertCustomQuantOpPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect>();
registry.insert<func::FuncDialect>();
@ -213,8 +222,8 @@ class ConvertCustomQuantOpPass
target.addIllegalOp<OperatorOp>();
patterns.add<ConvertCustomQuantizedMatmulOp>(typeConverter, context);
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};

View File

@ -33,7 +33,6 @@ using namespace mlir::torch;
using namespace mlir::torch::TorchConversion;
using namespace TMTensor;
namespace {
class VerifyLinalgOnTensorsBackendContractPass
: public VerifyLinalgOnTensorsBackendContractBase<
@ -96,7 +95,8 @@ class VerifyLinalgOnTensorsBackendContractPass
// We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
// doesn't unnecessarily spew out the entire module.
emitError(module.getLoc())
<< "Module does not conform to the linalg-on-tensors backend contract. "
<< "Module does not conform to the linalg-on-tensors backend "
"contract. "
"See dialect conversion legality information above.";
return signalPassFailure();
}

View File

@ -45,7 +45,8 @@ class VerifyStablehloBackendContractPass
ConversionTarget target(*context);
// Structural operations.
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(opHasLegalTypes);
target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(
opHasLegalTypes);
// Shape operations.
target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes);

View File

@ -35,14 +35,14 @@ TorchMlirBackendData::TorchMlirBackendData(
: BackendData(device, shape), info_(info) {
PRINT_FUNCTION();
}
TorchMlirBackendData::TorchMlirBackendData(
const at::Scalar& scalar, BackendDevice device)
TorchMlirBackendData::TorchMlirBackendData(const at::Scalar &scalar,
BackendDevice device)
: BackendData(device, Shape(scalar.type(), {})),
info_(std::make_shared<TorchMlirBackendData::Info>(scalar)) {
PRINT_FUNCTION();
}
TorchMlirBackendData::TorchMlirBackendData(
const at::Tensor& tensor, BackendDevice device, Shape shape)
TorchMlirBackendData::TorchMlirBackendData(const at::Tensor &tensor,
BackendDevice device, Shape shape)
: BackendData(device, shape),
info_(std::make_shared<TorchMlirBackendData::Info>(tensor)) {
PRINT_FUNCTION();
@ -55,8 +55,7 @@ BackendData::Handle TorchMlirBackendData::GetHandle() {
void TorchMlirBackendData::Assign(const BackendData &data) {
const TorchMlirBackendData *torch_mlir_data =
dynamic_cast<const TorchMlirBackendData *>(&data);
TORCH_CHECK(
torch_mlir_data,
TORCH_CHECK(torch_mlir_data,
"Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
info_ = torch_mlir_data->info_;
@ -99,8 +98,9 @@ BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar(
return std::make_shared<TorchMlirBackendData>(scalar, device);
}
BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder(
const BackendDevice& device, const Shape& shape) const {
BackendDataPtr
TorchMlirBackendImpl::CreateDataPlaceholder(const BackendDevice &device,
const Shape &shape) const {
PRINT_FUNCTION();
return std::make_shared<TorchMlirBackendData>(device, shape);
}
@ -122,8 +122,7 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
TorchMlirBackendData *torch_mlir_data =
dynamic_cast<TorchMlirBackendData *>(data.get());
TORCH_CHECK(
torch_mlir_data,
TORCH_CHECK(torch_mlir_data,
"Invalid Backend Data Pointer. Expected TorchMlirBackendData.");
TorchMlirBackendData::Info *info =
@ -141,7 +140,8 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData(
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
const std::string &name, BackendDevice device,
c10::ArrayRef<const Node*> post_order, Util::EmissionMap emit_status) const {
c10::ArrayRef<const Node *> post_order,
Util::EmissionMap emit_status) const {
PRINT_FUNCTION();
return std::make_unique<TorchMlirLoweringContext>(
name, std::forward<BackendDevice>(device),
@ -149,8 +149,9 @@ std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
std::forward<Util::EmissionMap>(emit_status));
}
std::unique_ptr<LoweringContext> TorchMlirBackendImpl::CreateLoweringContext(
const std::string& name, BackendDevice device) const {
std::unique_ptr<LoweringContext>
TorchMlirBackendImpl::CreateLoweringContext(const std::string &name,
BackendDevice device) const {
PRINT_FUNCTION();
return std::make_unique<TorchMlirLoweringContext>(
name, std::forward<BackendDevice>(device));
@ -175,8 +176,7 @@ at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const {
// Query all available backend devices
std::vector<BackendDevice> TorchMlirBackendImpl::GetBackendDevices() const {
PRINT_FUNCTION();
return {
GetBackendDevice(c10::Device(c10::kLazy, 0)),
return {GetBackendDevice(c10::Device(c10::kLazy, 0)),
GetBackendDevice(c10::Device(c10::kCPU, 0))};
}

View File

@ -50,10 +50,11 @@ public:
};
TorchMlirBackendData(BackendDevice device, Shape shape);
TorchMlirBackendData(BackendDevice device, Shape shape, std::shared_ptr<BackendData::Info> info);
TorchMlirBackendData(BackendDevice device, Shape shape,
std::shared_ptr<BackendData::Info> info);
TorchMlirBackendData(const at::Scalar &scalar, BackendDevice device);
TorchMlirBackendData(
const at::Tensor& tensor, BackendDevice device, Shape shape);
TorchMlirBackendData(const at::Tensor &tensor, BackendDevice device,
Shape shape);
virtual BackendData::Handle GetHandle() override;
@ -91,19 +92,22 @@ public:
* Data Transfer
* */
virtual BackendDataPtr MakeComputationDataFromTensor(
const at::Tensor& tensor, const Shape& shape,
virtual BackendDataPtr
MakeComputationDataFromTensor(const at::Tensor &tensor, const Shape &shape,
const BackendDevice &device) const override;
virtual BackendDataPtr MakeComputationDataFromScalar(
const at::Scalar& scalar, const BackendDevice& device) const override;
virtual BackendDataPtr
MakeComputationDataFromScalar(const at::Scalar &scalar,
const BackendDevice &device) const override;
virtual BackendDataPtr CreateDataPlaceholder(
const BackendDevice& device, const Shape& shape) const override;
virtual BackendDataPtr
CreateDataPlaceholder(const BackendDevice &device,
const Shape &shape) const override;
// Gets backend data if the node is a device data node. Otherwise returns
// nullptr.
virtual BackendDataPtr GetComputationDataFromNode(const Node*) const override;
virtual BackendDataPtr
GetComputationDataFromNode(const Node *) const override;
virtual at::Tensor MakeTensorFromComputationData(
const BackendDataPtr data,
@ -113,13 +117,14 @@ public:
* Lowering, Compilation, Execution
* */
virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
const std::string& name, BackendDevice device,
virtual std::unique_ptr<LoweringContext>
CreateLoweringContext(const std::string &name, BackendDevice device,
c10::ArrayRef<const Node *> post_order,
Util::EmissionMap emit_status) const override;
virtual std::unique_ptr<LoweringContext> CreateLoweringContext(
const std::string& name, BackendDevice device) const override;
virtual std::unique_ptr<LoweringContext>
CreateLoweringContext(const std::string &name,
BackendDevice device) const override;
// TODO(whc) need to keep this?
// virtual std::vector<std::string> GetCompilationDevices(

View File

@ -16,15 +16,13 @@ namespace torch {
namespace lazy {
DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed)
: TorchMlirNode(
op, operands, /*num_outputs=*/1,
: TorchMlirNode(op, operands, /*num_outputs=*/1,
/* hash_seed */ HashCombine(op.hash(), hash_seed)) {}
std::string DimensionNode::ToString() const { return "DimensionNode"; }
SizeNode::SizeNode(Value input, size_t dim)
: DimensionNode(
OpKind{c10::Symbol::fromQualString("aten::size")}, {input},
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::size")}, {input},
MHash(dim)),
dim_(dim){};
@ -40,7 +38,8 @@ SizeAdd::SizeAdd(Value a, Value b)
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){};
int64_t SizeAdd::getStaticValue() const {
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() +
return dynamic_cast<const DimensionNode *>(operand(0).node)
->getStaticValue() +
dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue();
}
@ -50,7 +49,8 @@ SizeMul::SizeMul(Value a, Value b)
: DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){};
int64_t SizeMul::getStaticValue() const {
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() *
return dynamic_cast<const DimensionNode *>(operand(0).node)
->getStaticValue() *
dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue();
}
@ -64,7 +64,8 @@ int64_t SizeDiv::getStaticValue() const {
dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue() !=
0,
"Can't divide a dimension by zero");
return dynamic_cast<const DimensionNode*>(operand(0).node)->getStaticValue() /
return dynamic_cast<const DimensionNode *>(operand(0).node)
->getStaticValue() /
dynamic_cast<const DimensionNode *>(operand(1).node)->getStaticValue();
}

View File

@ -12,14 +12,14 @@
#include <iostream>
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include <torch/csrc/lazy/core/config.h>
#include "torch-mlir-c/Registration.h"
#include "torch-mlir-c/Transforms.h"
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "torch-mlir-c/Registration.h"
#include "torch-mlir-c/Transforms.h"
#include <torch/csrc/jit/api/compilation_unit.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/lazy/core/config.h>
#include <torch/csrc/lazy/core/lazy_graph_executor.h>
#include "backend_impl.h"
#include "jit_ir_importer/function_importer.h"
@ -38,8 +38,8 @@ namespace lazy {
// TorchMlir Lowering Context
///////////////////////////////////////////////////////////////////////////////
TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string& name, BackendDevice device)
TorchMlirLoweringContext::TorchMlirLoweringContext(const std::string &name,
BackendDevice device)
: LoweringContext(name, std::forward<BackendDevice>(device)),
graph_(std::make_shared<torch::jit::Graph>()),
function_(
@ -50,7 +50,8 @@ TorchMlirLoweringContext::TorchMlirLoweringContext(
TorchMlirLoweringContext::TorchMlirLoweringContext(
const std::string &name, BackendDevice device,
c10::ArrayRef<const torch::lazy::Node*> post_order, Util::EmissionMap emit_status)
c10::ArrayRef<const torch::lazy::Node *> post_order,
Util::EmissionMap emit_status)
: LoweringContext(
name, std::forward<BackendDevice>(device),
std::forward<c10::ArrayRef<const torch::lazy::Node *>>(post_order),
@ -90,9 +91,9 @@ void TorchMlirLoweringContext::SetUpAlias(
bool TorchMlirLoweringContext::CheckResultShape(
const BackendDataPtr &parameter_data, size_t result_idx) {
TORCH_CHECK(
result_idx < root_tuple_.size(), "Tried getting result shape at index ",
result_idx, " which is out of bounds!");
TORCH_CHECK(result_idx < root_tuple_.size(),
"Tried getting result shape at index ", result_idx,
" which is out of bounds!");
torch::jit::Value *output = root_tuple_[result_idx];
@ -120,9 +121,10 @@ size_t TorchMlirLoweringContext::AddResult(const Output& output) {
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
void TorchMlirLoweringContext::AddParameter(
const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape, const std::string& name) {
void TorchMlirLoweringContext::AddParameter(const torch::lazy::Output &output,
size_t index,
const torch::lazy::Shape &shape,
const std::string &name) {
UNIMPLEMENTED_FUNCTION_ERROR();
}
@ -152,7 +154,6 @@ ComputationPtr TorchMlirLoweringContext::Build() {
/*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; },
/*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true});
// Convert MlirOperation to MlirModule.
MlirLocation loc = mlirLocationUnknownGet(mlir_context_);
MlirModule module_op = mlirModuleCreateEmpty(loc);
@ -162,14 +163,10 @@ ComputationPtr TorchMlirLoweringContext::Build() {
// Apply passes to verify generated MLIR.
auto pass_manager = mlirPassManagerCreate(mlir_context_);
mlirPassManagerAddOwnedPass(
pass_manager,
mlirCreateVerifyBackendContractNoDecompositions()
);
pass_manager, mlirCreateVerifyBackendContractNoDecompositions());
MlirLogicalResult result = mlirPassManagerRunOnOp(
pass_manager,
mlirModuleGetOperation(module_op)
);
MlirLogicalResult result =
mlirPassManagerRunOnOp(pass_manager, mlirModuleGetOperation(module_op));
if (mlirLogicalResultIsFailure(result)) {
throw std::runtime_error("MLIR verification has failed.");
@ -178,9 +175,11 @@ ComputationPtr TorchMlirLoweringContext::Build() {
return CreateComputation(module_op);
}
ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirModule module_op) {
return std::make_shared<TorchMlirComputation>(
module_op, mlir_context_, graph_, parameter_names_, input_output_aliases_);
ComputationPtr
TorchMlirLoweringContext::CreateComputation(MlirModule module_op) {
return std::make_shared<TorchMlirComputation>(module_op, mlir_context_,
graph_, parameter_names_,
input_output_aliases_);
}
torch::jit::Value *TorchMlirLoweringContext::GetOutputOp(const Output &output) {
@ -195,15 +194,14 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) {
// At this point the output better be present, otherwise there is an issue
// with the lowering code.
it = emitted_outputs_.find(output);
TORCH_CHECK(
it != emitted_outputs_.end(),
TORCH_CHECK(it != emitted_outputs_.end(),
"No MLIR operation emitted for output: ", output.ToString());
}
return it->second;
}
void TorchMlirLoweringContext::AssignOutputOp(
const Output& output, torch::jit::Value* op) {
void TorchMlirLoweringContext::AssignOutputOp(const Output &output,
torch::jit::Value *op) {
PRINT_FUNCTION();
auto torch_mlir_node =
@ -234,17 +232,13 @@ void TorchMlirLoweringContext::AssignOutputOp(
});
if (!source_files.empty()) {
op->node()->ss_(
c10::Symbol::attr("source_files"), source_files);
op->node()->ss_(
c10::Symbol::attr("functions"), functions);
op->node()->is_(
c10::Symbol::attr("line_numbers"), line_numbers);
op->node()->ss_(c10::Symbol::attr("source_files"), source_files);
op->node()->ss_(c10::Symbol::attr("functions"), functions);
op->node()->is_(c10::Symbol::attr("line_numbers"), line_numbers);
}
}
auto scope = ::c10::Symbol::scope(metadata.scope);
op->node()->setScope(
c10::make_intrusive<torch::jit::Scope>()->push(scope));
op->node()->setScope(c10::make_intrusive<torch::jit::Scope>()->push(scope));
emitted_outputs_[output] = std::move(op);
}
@ -266,7 +260,8 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
torch::jit::Value *param =
graph_->addInput(c10::str("p", parameters_.size()));
auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
auto *info =
dynamic_cast<TorchMlirBackendData::Info *>(mlir_data->mlir_info());
TORCH_CHECK(info, "Expected TorchMlirBackendData::Info");
if (info->scalar.has_value()) {
auto &scalar = info->scalar.value();
@ -275,8 +270,8 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) {
} else if (scalar.isIntegral(true)) {
param->setType(c10::IntType::get());
} else {
TORCH_CHECK(
false, "Unhandled scalar type: ", c10::toString(scalar.type()));
TORCH_CHECK(false,
"Unhandled scalar type: ", c10::toString(scalar.type()));
}
} else {
// Save parameter shape information.
@ -313,8 +308,8 @@ size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) {
// Sync vector of c10::Argument with type specified from parallel list of
// jit::Value. There must be a 1:1 map between elements of args and values.
std::vector<c10::Argument> sync_argument_types(
const std::vector<c10::Argument>& args,
std::vector<c10::Argument>
sync_argument_types(const std::vector<c10::Argument> &args,
c10::ArrayRef<torch::jit::Value *> values) {
TORCH_CHECK(
args.size() == values.size(),
@ -377,9 +372,7 @@ TorchMlirComputation::TorchMlirComputation(
}
}
int TorchMlirComputation::parameters_size() const {
return num_parameters_;
}
int TorchMlirComputation::parameters_size() const { return num_parameters_; }
const std::vector<torch::lazy::Shape> &
TorchMlirComputation::parameter_shapes() const {
@ -392,7 +385,8 @@ const std::vector<std::string>& TorchMlirComputation::parameter_names() const {
return parameter_names_;
}
const std::unordered_map<int, std::string>& TorchMlirComputation::parameters_map() const {
const std::unordered_map<int, std::string> &
TorchMlirComputation::parameters_map() const {
return parameters_map_;
}
@ -411,13 +405,9 @@ MlirOperation TorchMlirComputation::func_op() const {
return mlirBlockGetFirstOperation(block);
}
MlirModule TorchMlirComputation::module_op() const {
return module_op_;
}
MlirModule TorchMlirComputation::module_op() const { return module_op_; }
MlirContext TorchMlirComputation::mlir_context() const {
return mlir_context_;
}
MlirContext TorchMlirComputation::mlir_context() const { return mlir_context_; }
const std::string TorchMlirComputation::debug_string() const {
std::stringstream ss;
@ -462,7 +452,8 @@ const std::string TorchMlirComputation::to_string() const {
// Setup flags for MLIR serialization.
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
mlirOpPrintingFlagsEnableDebugInfo(flags, FLAGS_torch_lazy_ir_debug, false);
mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, print_callback, &ss);
mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags,
print_callback, &ss);
return ss.str();
}

View File

@ -39,24 +39,23 @@ public:
};
using InputOutputAliases = std::vector<InputOutputAlias>;
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device);
TorchMlirLoweringContext(
const std::string& name, torch::lazy::BackendDevice device,
TorchMlirLoweringContext(const std::string &name,
torch::lazy::BackendDevice device);
TorchMlirLoweringContext(const std::string &name,
torch::lazy::BackendDevice device,
c10::ArrayRef<const torch::lazy::Node *> post_order,
torch::lazy::Util::EmissionMap emit_status);
void Lower(const Node *node);
// Adds a new input/output alias.
void SetUpAlias(
const std::vector<int64_t>& output_index, int64_t param_number,
const std::vector<int64_t>& param_index,
void SetUpAlias(const std::vector<int64_t> &output_index,
int64_t param_number, const std::vector<int64_t> &param_index,
bool must_alias = false) override;
// Check if parameter shape matches result at index.
bool CheckResultShape(
const BackendDataPtr& parameter_data, size_t result_idx) override;
bool CheckResultShape(const BackendDataPtr &parameter_data,
size_t result_idx) override;
// Adds the given output as a component of the result tuple and returns its
// assigned position within the tuple.
@ -65,9 +64,9 @@ public:
// Associates the given output with the input parameter of the given index and
// shape. Only used for the operator-by-operator execution, mostly for
// debugging purposes.
void AddParameter(
const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape, const std::string& name) override;
void AddParameter(const torch::lazy::Output &output, size_t index,
const torch::lazy::Shape &shape,
const std::string &name) override;
// Build the computation capturing all the operations created with the
// embedded builder (returned by the builder() API).
@ -122,8 +121,7 @@ public:
using InputOutputAliases = TorchMlirLoweringContext::InputOutputAliases;
using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias;
TorchMlirComputation(
MlirModule module_op, MlirContext mlir_context,
TorchMlirComputation(MlirModule module_op, MlirContext mlir_context,
const std::shared_ptr<torch::jit::Graph> &graph,
std::unordered_map<int, std::string> parameters_map,
InputOutputAliases input_output_aliases);

View File

@ -10,8 +10,8 @@
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp
//===----------------------------------------------------------------------===//
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/InferSize.h>
#include <ATen/MetaFunctions.h>
@ -33,11 +33,11 @@
#include "generated/LazyIr.h"
#include "generated/LazyNativeFunctions.h"
#include "generated/shape_inference.h"
#include "ops/to_copy.h"
#include "ops/unbind_int.h"
#include "ops/split.h"
#include "ops/index.h"
#include "ops/ivalue.h"
#include "ops/split.h"
#include "ops/to_copy.h"
#include "ops/unbind_int.h"
#include "utils/exception.h"
#include "utils/sys_utils.h"
@ -76,7 +76,8 @@ std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
return outs;
}
c10::List<c10::optional<at::Tensor>> to_meta(const c10::List<c10::optional<at::Tensor>>& t_list) {
c10::List<c10::optional<at::Tensor>>
to_meta(const c10::List<c10::optional<at::Tensor>> &t_list) {
c10::List<c10::optional<at::Tensor>> outs;
outs.reserve(t_list.size());
for (const auto &tensor : t_list) {
@ -91,8 +92,8 @@ namespace lazy {
namespace {
at::Tensor CreateLtcTensor(
const at::Tensor& tensor,
at::Tensor
CreateLtcTensor(const at::Tensor &tensor,
const c10::optional<torch::lazy::BackendDevice> &device) {
if (tensor.defined() && device) {
return torch::lazy::CreateAtenFromLtcTensor(
@ -112,13 +113,12 @@ GetLtcDevice(const c10::optional<c10::Device>& device) {
return torch::lazy::atenDeviceToBackendDevice(*device);
}
torch::lazy::Value MaybeExpand(
const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) {
torch::lazy::Value MaybeExpand(const torch::lazy::Value &input,
const torch::lazy::Shape &target_shape) {
if (input.shape().sizes() == target_shape.sizes()) {
return input;
}
return torch::lazy::MakeExpand(
input, target_shape.sizes().vec(),
return torch::lazy::MakeExpand(input, target_shape.sizes().vec(),
/*is_scalar_expand=*/false);
}
@ -128,8 +128,8 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
if (input->dtype() == src->dtype()) {
copy_value = src->GetIrValue();
} else {
copy_value = torch::lazy::MakeCast(
src->GetIrValue(), input->dtype(), src->dtype());
copy_value = torch::lazy::MakeCast(src->GetIrValue(), input->dtype(),
src->dtype());
}
input->SetIrValue(MaybeExpand(copy_value, input->shape()));
} else {
@ -146,15 +146,17 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
// clone is special in LT because we make it a no-op.
// This should be safe to do, because every operator in the LT is functional.
at::Tensor LazyNativeFunctions::clone(
const at::Tensor& self, c10::optional<at::MemoryFormat> memory_format) {
at::Tensor
LazyNativeFunctions::clone(const at::Tensor &self,
c10::optional<at::MemoryFormat> memory_format) {
auto self_lt = torch::lazy::TryGetLtcTensor(self);
return torch::lazy::CreateAtenFromLtcTensor(
self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice()));
}
at::Tensor LazyNativeFunctions::_copy_from(
const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
at::Tensor LazyNativeFunctions::_copy_from(const at::Tensor &self,
const at::Tensor &dst,
bool non_blocking) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
@ -207,8 +209,8 @@ at::Tensor LazyNativeFunctions::_copy_from(
return dst;
}
at::Tensor LazyNativeFunctions::_copy_from_and_resize(
const at::Tensor& self, const at::Tensor& dst) {
at::Tensor LazyNativeFunctions::_copy_from_and_resize(const at::Tensor &self,
const at::Tensor &dst) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto dst_tensor = torch::lazy::TryGetLtcTensor(dst);
auto self_tensor = torch::lazy::TryGetLtcTensor(self);
@ -239,8 +241,9 @@ at::Tensor LazyNativeFunctions::_to_copy(
PRINT_FUNCTION();
auto options = self.options();
if (dtype) {
// I put each of these setters in a conditional instead of doing `self.options().dtype(dtype).layout(layout)...
// because calling .dtype(nullopt) on an options() that already has dtype appears to wipe it
// I put each of these setters in a conditional instead of doing
// `self.options().dtype(dtype).layout(layout)... because calling
// .dtype(nullopt) on an options() that already has dtype appears to wipe it
options = options.dtype(dtype);
}
if (layout) {
@ -261,8 +264,9 @@ at::Tensor LazyNativeFunctions::_to_copy(
if (!lazy_self && device && device->type() == c10::kLazy) {
// Case 1: eager->lazy (we create a new lazy tensor)
// See Note [Lazy Tensor Functionalization]
// Invariant: if the functionalization key is in the exclude set, then we're expected
// to return an ordinary tensor, which will be "lifted" into a functional wrapper later.
// Invariant: if the functionalization key is in the exclude set, then we're
// expected to return an ordinary tensor, which will be "lifted" into a
// functional wrapper later.
bool functionalize_output =
!c10::impl::tls_local_dispatch_key_set().excluded_.has(
c10::DispatchKey::Functionalize);
@ -270,7 +274,8 @@ at::Tensor LazyNativeFunctions::_to_copy(
self, options, *device, /*non_blocking=*/non_blocking,
/*functionalize_output=*/functionalize_output);
} else if (device && device->type() != c10::kLazy) {
// Case 2: lazy->eager (forces a graph break since we are materializing a tensor)
// Case 2: lazy->eager (forces a graph break since we are materializing a
// tensor)
TORCH_INTERNAL_ASSERT(lazy_self);
auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
@ -278,22 +283,24 @@ at::Tensor LazyNativeFunctions::_to_copy(
auto moved_eager_tensor =
eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true);
return moved_eager_tensor;
} else if (
device && device->type() == c10::kLazy && device->has_index() &&
} else if (device && device->type() == c10::kLazy && device->has_index() &&
device->index() != self.device().index()) {
// Case 3: lazy:0 -> lazy:1
// TODO(whc) what do we actually want to do here?
// option 1: materialize, move eager tensor, create new lazy tensor
// - this should be our default, as it is what would happen before we implemented _to_copy
// - this should be our default, as it is what would happen before we
// implemented _to_copy
// - actually combines case 1 + case 2
// option 2: support multiple devices inside one lazy/TS executor (case 4)
// - but: we may have other assumptions that there is just one device per executor? so don't take this lightly
// - but: we may have other assumptions that there is just one device
// per executor? so don't take this lightly
TORCH_INTERNAL_ASSERT(lazy_self);
auto eager_tensor = lazy_self->ToTensor(/*detached=*/true);
// we move the eager tensor to the 'eager' equivalent of our lazy device
// e.g. if our device is lazy:1, the backend maps that to cuda:1, which is what we use
// e.g. if our device is lazy:1, the backend maps that to cuda:1, which is
// what we use
auto eager_device = c10::Device(
torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index());
options = options.device(eager_device);
@ -305,12 +312,14 @@ at::Tensor LazyNativeFunctions::_to_copy(
return torch::lazy::CreateAtenFromLtcTensor(lazy_self);
} else {
// Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy graph)
// Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy
// graph)
// Note: captured _to_copy will be executed with real eager tensors, not lazy tensors.
// We DO NOT want to burn 'lazy:0' as the device into this captured IR, or we will try to
// convert an eager tensor back to a lazy one inside the torchscript executor
// lazy:0 -> lazy:1 is handled in case3, so we can safely drop the device argument
// Note: captured _to_copy will be executed with real eager tensors, not
// lazy tensors. We DO NOT want to burn 'lazy:0' as the device into this
// captured IR, or we will try to convert an eager tensor back to a lazy one
// inside the torchscript executor lazy:0 -> lazy:1 is handled in case3, so
// we can safely drop the device argument
device = c10::nullopt;
auto shapes = torch::lazy::compute_shape__to_copy(
@ -327,10 +336,11 @@ at::Tensor LazyNativeFunctions::_to_copy(
}
};
at::Tensor LazyNativeFunctions::_unsafe_view(
const at::Tensor& self, at::IntArrayRef size) {
at::Tensor LazyNativeFunctions::_unsafe_view(const at::Tensor &self,
at::IntArrayRef size) {
TORCH_LAZY_FN_COUNTER("lazy::");
return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size));
return LazyNativeFunctions::view_copy_symint(self,
c10::fromIntArrayRefSlow(size));
}
at::Tensor LazyNativeFunctions::t(const at::Tensor &self) {
@ -338,156 +348,180 @@ at::Tensor LazyNativeFunctions::t(const at::Tensor& self) {
return at::functionalization::functionalize_aten_op<ATEN_OP(t)>::call(self);
}
std::vector<at::Tensor> LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) {
std::vector<at::Tensor> LazyNativeFunctions::unbind_copy(const at::Tensor &self,
int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
TORCH_INTERNAL_ASSERT(common_device);
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim);
LazyTensorPtr lazy_self =
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
torch::lazy::NodePtr node =
torch::lazy::ReuseNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim);
if (!node) {
auto self_meta = to_meta(self);
auto out_meta = at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim);
auto out_meta =
at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim);
std::vector<torch::lazy::Shape> shapes;
for (const auto &shape : out_meta) {
shapes.push_back(
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())
);
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()));
}
if (torch::lazy::symbolicShapeEnabled()) {
std::vector<torch::jit::IValue> inputs = {self, dim};
const char* schema_str = "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]";
const char *schema_str =
"aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}
node = torch::lazy::MakeNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim, std::move(shapes));
node = torch::lazy::MakeNode<UnbindCopyInt>(lazy_self->GetIrValue(), dim,
std::move(shapes));
CacheNode(node);
}
std::vector<at::Tensor> result;
for (size_t i = 0; i < node->num_outputs(); ++i) {
result.push_back(
torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device)
)
);
torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
torch::lazy::Value(node, i), *common_device)));
}
return result;
}
std::vector<at::Tensor> LazyNativeFunctions::split_with_sizes_copy_symint(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) {
std::vector<at::Tensor> LazyNativeFunctions::split_with_sizes_copy_symint(
const at::Tensor &self, c10::SymIntArrayRef split_sizes, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
TORCH_INTERNAL_ASSERT(common_device);
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitWithSizesCopy>(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim);
LazyTensorPtr lazy_self =
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitWithSizesCopy>(
lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim);
if (!node) {
auto self_meta = to_meta(self);
auto out_meta = at::compositeexplicitautogradnonfunctional::split_with_sizes_copy_symint(self_meta, split_sizes, dim);
auto out_meta = at::compositeexplicitautogradnonfunctional::
split_with_sizes_copy_symint(self_meta, split_sizes, dim);
std::vector<torch::lazy::Shape> shapes;
for (const auto &shape : out_meta) {
shapes.push_back(
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())
);
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()));
}
if (torch::lazy::symbolicShapeEnabled()) {
std::vector<torch::jit::IValue> inputs = {self, split_sizes, dim};
const char* schema_str = "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]";
const char *schema_str = "aten::split_with_sizes_copy(Tensor self, "
"SymInt[] split_sizes, int dim=0) -> Tensor[]";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}
node = torch::lazy::MakeNode<SplitWithSizesCopy>(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, std::move(shapes));
node = torch::lazy::MakeNode<SplitWithSizesCopy>(
lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim,
std::move(shapes));
CacheNode(node);
}
std::vector<at::Tensor> result;
for (size_t i = 0; i < node->num_outputs(); ++i) {
result.push_back(
torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device)
)
);
torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
torch::lazy::Value(node, i), *common_device)));
}
return result;
}
std::vector<at::Tensor> LazyNativeFunctions::split_copy_symint(const at::Tensor & self, c10::SymInt split_size, int64_t dim) {
std::vector<at::Tensor>
LazyNativeFunctions::split_copy_symint(const at::Tensor &self,
c10::SymInt split_size, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
TORCH_INTERNAL_ASSERT(common_device);
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitCopyTensor>(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim);
LazyTensorPtr lazy_self =
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<SplitCopyTensor>(
lazy_self->GetIrValue(), GetSymIntValue(split_size), dim);
if (!node) {
auto self_meta = to_meta(self);
auto out_meta = at::compositeexplicitautogradnonfunctional::split_copy_symint(self_meta, split_size, dim);
auto out_meta =
at::compositeexplicitautogradnonfunctional::split_copy_symint(
self_meta, split_size, dim);
std::vector<torch::lazy::Shape> shapes;
for (const auto &shape : out_meta) {
shapes.push_back(
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())
);
torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()));
}
const size_t num_outputs = shapes.size();
if (torch::lazy::symbolicShapeEnabled()) {
std::vector<torch::jit::IValue> inputs = {self, split_size, dim};
const char* schema_str = "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]";
const char *schema_str = "aten::split_copy.Tensor(Tensor self, SymInt "
"split_size, int dim=0) -> Tensor[]";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}
node = torch::lazy::MakeNode<SplitCopyTensor>(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, std::move(shapes), num_outputs);
node = torch::lazy::MakeNode<SplitCopyTensor>(
lazy_self->GetIrValue(), GetSymIntValue(split_size), dim,
std::move(shapes), num_outputs);
CacheNode(node);
}
std::vector<at::Tensor> result;
for (size_t i = 0; i < node->num_outputs(); ++i) {
result.push_back(
torch::lazy::CreateAtenFromLtcTensor(
torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device)
)
);
torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create(
torch::lazy::Value(node, i), *common_device)));
}
return result;
}
at::Tensor LazyNativeFunctions::index(const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices) {
at::Tensor LazyNativeFunctions::index(
const at::Tensor &self,
const c10::List<c10::optional<at::Tensor>> &indices) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
TORCH_INTERNAL_ASSERT(common_device);
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
LazyTensorPtr lazy_self =
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
std::vector<torch::lazy::Value> values;
for (const auto &it : indices) {
c10::optional<at::Tensor> tensor = it;
LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
values.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
LazyTensorPtr lazy_tensor =
torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
values.push_back(
lazy_tensor
? lazy_tensor->GetIrValue()
: torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
}
auto list = MakeNode<TorchMlirOptionalTensorList>(values);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<IndexTensor>(lazy_self->GetIrValue(), list);
torch::lazy::NodePtr node =
torch::lazy::ReuseNode<IndexTensor>(lazy_self->GetIrValue(), list);
if (!node) {
auto self_meta = to_meta(self);
auto indices_meta = to_meta(indices);
auto out_meta = at::meta::index(self_meta, indices_meta);
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
std::vector<torch::lazy::Shape> shapes{
torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
TORCH_INTERNAL_ASSERT(shapes.size() == 1);
if (torch::lazy::symbolicShapeEnabled()) {
std::vector<torch::jit::IValue> inputs = {self, indices};
const char* schema_str = "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor";
const char *schema_str =
"aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}
node = torch::lazy::MakeNode<IndexTensor>(lazy_self->GetIrValue(), list, std::move(shapes));
node = torch::lazy::MakeNode<IndexTensor>(lazy_self->GetIrValue(), list,
std::move(shapes));
CacheNode(node);
}
@ -497,40 +531,56 @@ at::Tensor LazyNativeFunctions::index(const at::Tensor & self, const c10::List<c
return result;
}
at::Tensor LazyNativeFunctions::index_put(const at::Tensor & self, const c10::List<c10::optional<at::Tensor>> & indices, const at::Tensor & values, bool accumulate) {
at::Tensor LazyNativeFunctions::index_put(
const at::Tensor &self, const c10::List<c10::optional<at::Tensor>> &indices,
const at::Tensor &values, bool accumulate) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto common_device = torch::lazy::GetBackendDevice(self);
TORCH_INTERNAL_ASSERT(common_device);
LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
LazyTensorPtr lazy_valeus = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device);
LazyTensorPtr lazy_self =
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device);
LazyTensorPtr lazy_valeus =
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device);
std::vector<torch::lazy::Value> indices_vector;
for (const auto &it : indices) {
c10::optional<at::Tensor> tensor = it;
LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
indices_vector.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
LazyTensorPtr lazy_tensor =
torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor()));
indices_vector.push_back(
lazy_tensor
? lazy_tensor->GetIrValue()
: torch::lazy::Value(MakeNode<IValueConstant>(c10::IValue()), 0));
}
auto indices_list = MakeNode<TorchMlirOptionalTensorList>(indices_vector);
torch::lazy::NodePtr node = torch::lazy::ReuseNode<IndexPut>(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate);
torch::lazy::NodePtr node =
torch::lazy::ReuseNode<IndexPut>(lazy_self->GetIrValue(), indices_list,
lazy_valeus->GetIrValue(), accumulate);
if (!node) {
auto self_meta = to_meta(self);
auto indices_meta = to_meta(indices);
auto values_meta = to_meta(values);
auto out_meta = at::compositeexplicitautograd::index_put(self_meta, indices_meta, values_meta, accumulate);
auto out_meta = at::compositeexplicitautograd::index_put(
self_meta, indices_meta, values_meta, accumulate);
std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
std::vector<torch::lazy::Shape> shapes{
torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
TORCH_INTERNAL_ASSERT(shapes.size() == 1);
if (torch::lazy::symbolicShapeEnabled()) {
std::vector<torch::jit::IValue> inputs = {self, indices, values};
const char* schema_str = "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor";
const char *schema_str =
"aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool "
"accumulate=False) -> Tensor";
applySymbolicShapesOnLT(schema_str, inputs, shapes);
}
node = torch::lazy::MakeNode<IndexPut>(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate, std::move(shapes));
node = torch::lazy::MakeNode<IndexPut>(
lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(),
accumulate, std::move(shapes));
CacheNode(node);
}
@ -542,7 +592,8 @@ at::Tensor LazyNativeFunctions::index_put(const at::Tensor & self, const c10::Li
// This is needed by the torch.tensor constructor.
// LazyTensor always opts into functionalization.
// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object.
// "lifting" a tensor for functionalization means wrapping it in a
// FunctionalTensorWrapper object.
at::Tensor LazyNativeFunctions::lift(const at::Tensor &tensor) {
TORCH_INTERNAL_ASSERT(
!at::functionalization::impl::isFunctionalTensor(tensor));
@ -555,29 +606,27 @@ at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) {
return at::functionalization::impl::to_functional_tensor(tensor);
}
// All of the below ops correspond to CompositeExplicitAutograd kernels from core
// that call into view operators internally.
// These are all composite ops that LTC can technically re-use / get for free,
// but we need to "functionalize" them to remove the view ops before we can use them.
// All of the below ops correspond to CompositeExplicitAutograd kernels from
// core that call into view operators internally. These are all composite ops
// that LTC can technically re-use / get for free, but we need to
// "functionalize" them to remove the view ops before we can use them.
at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
block_diag)>::call(tensors);
}
at::Tensor LazyNativeFunctions::new_empty_strided_symint(
const at::Tensor& self,
c10::SymIntArrayRef size,
c10::SymIntArrayRef stride,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
const at::Tensor &self, c10::SymIntArrayRef size,
c10::SymIntArrayRef stride, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
if (!device || device->type() == c10::DeviceType::Lazy) {
return at::functionalization::functionalize_aten_op_symint<
ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, layout,
device, pin_memory);
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
new_empty_strided)>::call(self, size, stride, dtype, layout, device,
pin_memory);
}
// For cases when device != lazy, for example: lazy_tensor.new_empty_strided(..., "cpu")
// we need to avoid explicit functionalization. To do that we create regular cpu tensors.
// For cases when device != lazy, for example:
// lazy_tensor.new_empty_strided(..., "cpu") we need to avoid explicit
// functionalization. To do that we create regular cpu tensors.
at::Tensor t = at::empty_symint(
size, (dtype ? dtype : c10::optional<at::ScalarType>(self.scalar_type())),
(layout ? layout : c10::optional<at::Layout>(self.layout())), device,
@ -585,43 +634,39 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint(
return t.as_strided_symint(size, stride, /*storage_offset=*/0);
}
at::Tensor LazyNativeFunctions::narrow_copy_symint(
const at::Tensor& self,
at::Tensor LazyNativeFunctions::narrow_copy_symint(const at::Tensor &self,
int64_t dim,
c10::SymInt start,
c10::SymInt length) {
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
narrow_copy)>::call(self, dim, start, length);
}
at::Tensor LazyNativeFunctions::pixel_shuffle(
const at::Tensor& self, int64_t upscale_factor) {
at::Tensor LazyNativeFunctions::pixel_shuffle(const at::Tensor &self,
int64_t upscale_factor) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
pixel_shuffle)>::call(self, upscale_factor);
}
at::Tensor LazyNativeFunctions::pixel_unshuffle(
const at::Tensor& self, int64_t downscale_factor) {
at::Tensor LazyNativeFunctions::pixel_unshuffle(const at::Tensor &self,
int64_t downscale_factor) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
pixel_unshuffle)>::call(self, downscale_factor);
}
at::Tensor LazyNativeFunctions::select_backward(
const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim,
int64_t index) {
at::Tensor LazyNativeFunctions::select_backward(const at::Tensor &grad_output,
at::IntArrayRef input_sizes,
int64_t dim, int64_t index) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
select_backward)>::call(grad_output, input_sizes, dim, index);
}
at::Tensor LazyNativeFunctions::slice_backward_symint(
const at::Tensor& grad_output,
at::SymIntArrayRef input_sizes,
int64_t dim,
c10::SymInt start,
c10::SymInt end,
c10::SymInt step) {
const at::Tensor &grad_output, at::SymIntArrayRef input_sizes, int64_t dim,
c10::SymInt start, c10::SymInt end, c10::SymInt step) {
return at::functionalization::functionalize_aten_op_symint<ATEN_OP(
slice_backward)>::call(grad_output, input_sizes, dim, start, end, step);
}
at::Tensor LazyNativeFunctions::diagonal_backward(
const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t offset,
int64_t dim1, int64_t dim2) {
at::Tensor LazyNativeFunctions::diagonal_backward(const at::Tensor &grad_output,
at::IntArrayRef input_sizes,
int64_t offset, int64_t dim1,
int64_t dim2) {
return at::functionalization::functionalize_aten_op<ATEN_OP(
diagonal_backward)>::call(grad_output, input_sizes, offset, dim1, dim2);
}
@ -629,8 +674,9 @@ at::Tensor LazyNativeFunctions::_trilinear(
const at::Tensor &i1, const at::Tensor &i2, const at::Tensor &i3,
at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3,
at::IntArrayRef sumdim, int64_t unroll_dim) {
return at::functionalization::functionalize_aten_op<ATEN_OP(_trilinear)>::
call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim);
return at::functionalization::functionalize_aten_op<ATEN_OP(
_trilinear)>::call(i1, i2, i3, expand1, expand2, expand3, sumdim,
unroll_dim);
}
at::Tensor LazyNativeFunctions::linalg_pinv(
const at::Tensor &self, const c10::optional<at::Tensor> &atol,
@ -640,10 +686,11 @@ at::Tensor LazyNativeFunctions::linalg_pinv(
}
// functionalize_aten_op can't handle out= ops directly.
// Instead, we can call the composite kernel from core, and copy and mutations back to the inputs.
at::Tensor& LazyNativeFunctions::logsumexp_out(
const at::Tensor& self, at::IntArrayRef dim, bool keepdim,
at::Tensor& out) {
// Instead, we can call the composite kernel from core, and copy and mutations
// back to the inputs.
at::Tensor &LazyNativeFunctions::logsumexp_out(const at::Tensor &self,
at::IntArrayRef dim,
bool keepdim, at::Tensor &out) {
auto self_wrapped = at::functionalization::impl::to_functional_tensor(self);
auto out_wrapped = at::functionalization::impl::to_functional_tensor(out);
// directly call the composite kernel from core.

View File

@ -18,8 +18,7 @@ namespace lazy {
namespace {
hash_t OperandHashes(
const OpList& operands, const c10::ArrayRef<Shape>& shapes,
hash_t OperandHashes(const OpList &operands, const c10::ArrayRef<Shape> &shapes,
const hash_t &seed, bool bakeInSizes) {
hash_t hash = seed;
for (auto &operand : operands) {
@ -38,21 +37,20 @@ hash_t OperandHashes(
} // namespace
// Adds a static hook that is run after every single TorchMlirNode is initialized
// Adds a static hook that is run after every single TorchMlirNode is
// initialized
static std::vector<std::function<void(TorchMlirNode *)>> constructor_hooks;
void TorchMlirNode::addConstructorHook(std::function<void(TorchMlirNode *)> f) {
constructor_hooks.emplace_back(f);
}
TorchMlirNode::TorchMlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes, size_t num_outputs,
TorchMlirNode::TorchMlirNode(OpKind op, OpList operands,
std::vector<Shape> &&shapes, size_t num_outputs,
hash_t hash_seed)
: Node(op, operands, std::move(shapes), num_outputs) {
hash_seed = HashCombine(op.hash(), hash_seed);
shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true);
dag_hash_ =
(enableDynamicShape()
dag_hash_ = (enableDynamicShape()
? OperandHashes(operands, this->shapes(), hash_seed, false)
: shape_hash_);
@ -61,28 +59,27 @@ TorchMlirNode::TorchMlirNode(
}
}
TorchMlirNode::TorchMlirNode(
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
TorchMlirNode::TorchMlirNode(OpKind op, OpList operands,
const std::function<Shape()> &shape_fn,
size_t num_outputs, hash_t hash_seed)
: TorchMlirNode(
op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {
: TorchMlirNode(op, operands, std::vector<Shape>{}, num_outputs,
hash_seed) {
addComputedShape(shape_fn);
}
TorchMlirNode::TorchMlirNode(
OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed)
: TorchMlirNode(
op, operands, std::vector<Shape>{}, num_outputs, hash_seed) {}
TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, size_t num_outputs,
hash_t hash_seed)
: TorchMlirNode(op, operands, std::vector<Shape>{}, num_outputs,
hash_seed) {}
TorchMlirNode::TorchMlirNode(
OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed)
TorchMlirNode::TorchMlirNode(OpKind op, Shape shape, size_t num_outputs,
hash_t hash_seed)
: TorchMlirNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {}
hash_t TorchMlirNode::hash() const { return dag_hash_; }
hash_t TorchMlirNode::shapeHash() const { return shape_hash_; }
TorchMlirNode *TorchMlirNode::mlir_node(int index) const {
return dynamic_cast<TorchMlirNode *>(operands_.at(index).get());
}
@ -107,8 +104,9 @@ TorchMlirTensorList::TorchMlirTensorList(OpList values)
/*num_outputs=*/1,
/*hash_seed=*/kHashSeed) {}
torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
torch::lazy::TorchMlirOpVector
TorchMlirTensorList::Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const {
std::vector<torch::jit::Value *> tensor_list;
CHECK(!operands().empty());
for (const torch::lazy::Output &operand : operands()) {
@ -140,16 +138,17 @@ TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values)
/*num_outputs=*/1,
/*hash_seed=*/kHashSeed) {}
torch::lazy::TorchMlirOpVector TorchMlirOptionalTensorList::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
torch::lazy::TorchMlirOpVector
TorchMlirOptionalTensorList::Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const {
std::vector<torch::jit::Value *> tensor_list;
CHECK(!operands().empty());
for (const torch::lazy::Output &operand : operands()) {
tensor_list.emplace_back(loctx->GetOutputOp(operand));
}
auto graph = function->graph();
auto listnode =
graph->insertNode(graph->createList(c10::OptionalType::create(c10::TensorType::get()), tensor_list));
auto listnode = graph->insertNode(graph->createList(
c10::OptionalType::create(c10::TensorType::get()), tensor_list));
return {listnode->output()};
}

View File

@ -27,22 +27,21 @@ namespace lazy {
class TORCH_API TorchMlirNode : public torch::lazy::Node {
public:
TorchMlirNode(
OpKind op, OpList operands, std::vector<Shape>&& shapes,
TorchMlirNode(OpKind op, OpList operands, std::vector<Shape> &&shapes,
size_t num_outputs, hash_t hash_seed = kHashSeed);
TorchMlirNode(
OpKind op, OpList operands, const std::function<Shape()>& shape_fn,
size_t num_outputs, hash_t hash_seed = kHashSeed);
TorchMlirNode(
OpKind op, OpList operands, size_t num_outputs,
TorchMlirNode(OpKind op, OpList operands,
const std::function<Shape()> &shape_fn, size_t num_outputs,
hash_t hash_seed = kHashSeed);
TorchMlirNode(
OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed);
TorchMlirNode(OpKind op, OpList operands, size_t num_outputs,
hash_t hash_seed = kHashSeed);
// Adds a static hook that is run after every single TorchMlirNode is constructed
TorchMlirNode(OpKind op, Shape shape, size_t num_outputs,
hash_t hash_seed = kHashSeed);
// Adds a static hook that is run after every single TorchMlirNode is
// constructed
static void addConstructorHook(std::function<void(TorchMlirNode *)>);
~TorchMlirNode() override = default;
@ -53,8 +52,8 @@ public:
TorchMlirNode *mlir_node(int index) const;
virtual TorchMlirOpVector
Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const;
virtual TorchMlirOpVector Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const;
private:
// The hash of the dag WITH size info. Used for shape caching
@ -86,21 +85,22 @@ struct TORCH_API TorchMlirTensorList : public TorchMlirNode {
TorchMlirTensorList() = delete;
TorchMlirTensorList(OpList values);
torch::lazy::TorchMlirOpVector Lower(
TorchMlirFunction function,
torch::lazy::TorchMlirOpVector
Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const override;
};
// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also represent
// optional tensors, so the output type for this op is !torch.list<optional<vtensor>>.
// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also
// represent optional tensors, so the output type for this op is
// !torch.list<optional<vtensor>>.
struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode {
static OpKind ClassOpKind();
TorchMlirOptionalTensorList() = delete;
TorchMlirOptionalTensorList(OpList values);
torch::lazy::TorchMlirOpVector Lower(
TorchMlirFunction function,
torch::lazy::TorchMlirOpVector
Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const override;
};

View File

@ -31,8 +31,8 @@
namespace torch {
namespace lazy {
TorchMlirOpVector LowerTorchMlirBuiltin(
TorchMlirFunction function, c10::Symbol sym,
TorchMlirOpVector
LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym,
const std::vector<c10::TypePtr> tensor_types,
const std::vector<torch::jit::NamedValue> &arguments,
const std::vector<torch::jit::NamedValue> &kwarguments) {
@ -43,9 +43,11 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
for (auto arg : arguments) {
torch::jit::Value *value = arg.value(dummy_graph);
if (value->type()->kind() == c10::TypeKind::ListType) {
auto list_element_type = value->type()->cast<c10::ListType>()->getElementType();
auto list_element_type =
value->type()->cast<c10::ListType>()->getElementType();
if (list_element_type->cast<c10::OptionalType>()) {
value->setType(c10::ListType::create(c10::OptionalType::create(c10::TensorType::get())));
value->setType(c10::ListType::create(
c10::OptionalType::create(c10::TensorType::get())));
} else {
value->setType(c10::ListType::create(c10::TensorType::get()));
}
@ -61,16 +63,18 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
TorchMlirOpVector results;
if (sv->getValue()->type()->kind() == c10::TypeKind::ListType) {
// Unpack dynamic multi-output operations like aten::split with Tensor[] output type.
// This is required to have consistent input types for multi-output node consumers.
torch::jit::Node * node = function->graph()->createListUnpack(sv->getValue(), tensor_types.size());
// Unpack dynamic multi-output operations like aten::split with Tensor[]
// output type. This is required to have consistent input types for
// multi-output node consumers.
torch::jit::Node *node = function->graph()->createListUnpack(
sv->getValue(), tensor_types.size());
function->graph()->insertNode(node);
for (const auto &output : node->outputs()) {
results.push_back(output);
}
} else if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) {
// Op returns multiple values and the number of outputs is static and defined
// by the operation schema.
// Op returns multiple values and the number of outputs is static and
// defined by the operation schema.
const auto tuple_call_result = sv->asTuple({}, *function);
for (const auto &tuple_component : tuple_call_result) {
auto tuple_component_sv =
@ -97,16 +101,15 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
}
// Ensure that we use up all the known tensor type information available.
TORCH_CHECK(
tensor_type_idx == tensor_types.size(), tensor_type_idx,
" known types were injected into jit::Value, but ", tensor_types.size(),
" were provided from lazy::Node!");
TORCH_CHECK(tensor_type_idx == tensor_types.size(), tensor_type_idx,
" known types were injected into jit::Value, but ",
tensor_types.size(), " were provided from lazy::Node!");
return results;
}
TorchMlirOpVector LowerTorchMlirBuiltin(
TorchMlirFunction function, c10::Symbol sym,
TorchMlirOpVector
LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym,
const c10::ArrayRef<Shape> result_shapes,
const std::vector<torch::jit::NamedValue> &arguments,
const std::vector<torch::jit::NamedValue> &kwarguments) {
@ -122,27 +125,27 @@ TorchMlirOpVector LowerTorchMlirBuiltin(
/*requires_grad=*/c10::nullopt));
}
return LowerTorchMlirBuiltin(
function, sym, tensor_types, arguments, kwarguments);
return LowerTorchMlirBuiltin(function, sym, tensor_types, arguments,
kwarguments);
}
TorchMlirOpVector LowerBuiltin(
const torch::lazy::Node* node, TorchMlirFunction function,
TorchMlirOpVector
LowerBuiltin(const torch::lazy::Node *node, TorchMlirFunction function,
const std::vector<torch::jit::NamedValue> &arguments,
const std::vector<torch::jit::NamedValue> &kwarguments = {}) {
return LowerTorchMlirBuiltin(
function, node->op().op, node->shapes(), arguments, kwarguments);
return LowerTorchMlirBuiltin(function, node->op().op, node->shapes(),
arguments, kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
TorchMlirOpVector
LowerBuiltin(c10::Symbol sym, const c10::ArrayRef<Shape> result_shapes,
TorchMlirFunction function,
const std::vector<torch::jit::NamedValue> &arguments,
const std::vector<torch::jit::NamedValue> &kwarguments = {}) {
return LowerTorchMlirBuiltin(
function, sym, result_shapes, arguments, kwarguments);
return LowerTorchMlirBuiltin(function, sym, result_shapes, arguments,
kwarguments);
}
TorchMlirOpVector LowerBuiltin(
c10::Symbol sym, const std::vector<c10::TypePtr> types,
TorchMlirOpVector
LowerBuiltin(c10::Symbol sym, const std::vector<c10::TypePtr> types,
TorchMlirFunction function,
const std::vector<torch::jit::NamedValue> &arguments,
const std::vector<torch::jit::NamedValue> &kwarguments = {}) {
@ -181,14 +184,14 @@ std::vector<torch::lazy::Shape> compute_shape_copy(c10::TypePtr value_type) {
TORCH_CHECK(maybe_dims.has_value(), "Cannot copy unranked tensor!");
auto scalar_type = tensor_type.scalarType();
TORCH_CHECK(
scalar_type.has_value(), "Unable to copy due to lack of scalar type!");
TORCH_CHECK(scalar_type.has_value(),
"Unable to copy due to lack of scalar type!");
return {Shape(scalar_type.value(), maybe_dims.value())};
}
std::vector<torch::lazy::Shape> compute_shape_slice(
c10::TypePtr value_type, int64_t dim, int64_t start, int64_t end,
int64_t step) {
std::vector<torch::lazy::Shape> compute_shape_slice(c10::TypePtr value_type,
int64_t dim, int64_t start,
int64_t end, int64_t step) {
c10::TensorType &tensor_type = cast_tensor_type(value_type);
auto maybe_dims = get_tensor_type_shape(tensor_type);
@ -217,13 +220,13 @@ std::vector<torch::lazy::Shape> compute_shape_slice(
}
auto scalar_type = tensor_type.scalarType();
TORCH_CHECK(
scalar_type.has_value(), "Unable to slice due to lack of scalar type!");
TORCH_CHECK(scalar_type.has_value(),
"Unable to slice due to lack of scalar type!");
return {Shape(scalar_type.value(), dims)};
}
torch::jit::Value*
GenerateClone(torch::jit::Value* val, TorchMlirFunction function) {
torch::jit::Value *GenerateClone(torch::jit::Value *val,
TorchMlirFunction function) {
std::vector<torch::jit::NamedValue> clone_arguments;
clone_arguments.emplace_back(val);
@ -234,20 +237,19 @@ GenerateClone(torch::jit::Value* val, TorchMlirFunction function) {
return cloned.front();
}
void GenerateCopy(
torch::jit::Value* destination, torch::jit::Value* source,
void GenerateCopy(torch::jit::Value *destination, torch::jit::Value *source,
TorchMlirFunction function) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(destination);
arguments.emplace_back(source);
LowerBuiltin(
at::aten::copy_, c10::ArrayRef<Shape>(compute_shape_copy(source->type())),
LowerBuiltin(at::aten::copy_,
c10::ArrayRef<Shape>(compute_shape_copy(source->type())),
function, arguments);
}
torch::jit::Value* GenerateSlice(
torch::jit::Value* base, int64_t dim, int64_t start, int64_t end,
int64_t step, TorchMlirFunction function) {
torch::jit::Value *GenerateSlice(torch::jit::Value *base, int64_t dim,
int64_t start, int64_t end, int64_t step,
TorchMlirFunction function) {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(base);
arguments.emplace_back(dim);
@ -255,10 +257,10 @@ torch::jit::Value* GenerateSlice(
arguments.emplace_back(end);
arguments.emplace_back(step);
TorchMlirOpVector selected = LowerBuiltin(
at::aten::slice,
c10::ArrayRef<Shape>(
compute_shape_slice(base->type(), dim, start, end, step)),
TorchMlirOpVector selected =
LowerBuiltin(at::aten::slice,
c10::ArrayRef<Shape>(compute_shape_slice(base->type(), dim,
start, end, step)),
function, arguments);
TORCH_CHECK_EQ(selected.size(), 1);
return selected.front();
@ -267,8 +269,8 @@ torch::jit::Value* GenerateSlice(
// Node Lowerings
// Default Node Lowering
TorchMlirOpVector TorchMlirNode::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
TorchMlirOpVector TorchMlirNode::Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const {
std::vector<torch::jit::NamedValue> arguments;
for (const torch::lazy::Output &output : operands()) {
arguments.emplace_back(loctx->GetOutputOp(output));
@ -280,16 +282,16 @@ TorchMlirOpVector TorchMlirNode::Lower(
// Non-native nodes
TorchMlirOpVector
Cast::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
TorchMlirOpVector Cast::Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(dtype);
return LowerBuiltin(at::aten::to, shapes(), function, arguments);
}
TorchMlirOpVector DeviceData::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
TorchMlirOpVector DeviceData::Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const {
auto infoptr = data_->info();
auto deviceDataInfoPtr =
(torch::lazy::LazyGraphExecutor::DeviceDataInfo *)infoptr;
@ -300,8 +302,8 @@ TorchMlirOpVector DeviceData::Lower(
return {loctx->GetParameter(data_)};
}
TorchMlirOpVector Scalar::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
TorchMlirOpVector Scalar::Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const {
auto options =
at::TensorOptions()
.device(torch::lazy::getBackend()->EagerFallbackDeviceType())
@ -309,8 +311,8 @@ TorchMlirOpVector Scalar::Lower(
return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))};
}
TorchMlirOpVector Expand::Lower(
TorchMlirFunction function, TorchMlirLoweringContext* loctx) const {
TorchMlirOpVector Expand::Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const {
std::vector<torch::jit::NamedValue> arguments;
arguments.emplace_back(loctx->GetOutputOp(operand(0)));
arguments.emplace_back(size);

View File

@ -2,16 +2,14 @@
#include <torch/csrc/lazy/core/ir_builder.h>
#include "device_data.h"
#include "../backend_impl.h"
#include "device_data.h"
namespace torch {
namespace lazy {
DeviceData::DeviceData(std::shared_ptr<BackendData> data)
: TorchMlirNode(
ClassOpKind(),
data->shape(),
: TorchMlirNode(ClassOpKind(), data->shape(),
/*num_outputs=*/1,
/*hash_seed=*/static_cast<uint32_t>(101)),
data_(std::move(data)) {
@ -21,9 +19,11 @@ DeviceData::DeviceData(std::shared_ptr<BackendData> data)
void DeviceData::propagate_name() {
if (data_ && name_ != "") {
// Add device data name to backend data
TorchMlirBackendData* mlir_data = dynamic_cast<TorchMlirBackendData*>(data_.get());
TorchMlirBackendData *mlir_data =
dynamic_cast<TorchMlirBackendData *>(data_.get());
TORCH_CHECK(mlir_data);
auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
auto *info =
dynamic_cast<TorchMlirBackendData::Info *>(mlir_data->mlir_info());
TORCH_CHECK(info);
info->name = name_;
}

View File

@ -6,15 +6,12 @@
#include <torch/csrc/lazy/backend/backend_data.h>
#include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
namespace torch {
namespace lazy {
class TORCH_API DeviceData : public TorchMlirNode {
public:
static OpKind ClassOpKind() {
return ltc_device_data;
}
static OpKind ClassOpKind() { return ltc_device_data; }
explicit DeviceData(std::shared_ptr<BackendData> data);
@ -31,7 +28,8 @@ class TORCH_API DeviceData : public TorchMlirNode {
void SetData(std::shared_ptr<BackendData> data);
TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override;
TorchMlirOpVector Lower(TorchMlirFunction function,
TorchMlirLoweringContext *loctx) const override;
static const DeviceData *Cast(const Node *node);

View File

@ -15,11 +15,7 @@
namespace torch {
namespace lazy {
Generic::Generic(
OpKind op,
OpList operands,
Shape shape,
size_t num_outputs,
Generic::Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs,
hash_t hash_seed)
: TorchMlirNode(op, operands, {std::move(shape)}, num_outputs, hash_seed),
hash_seed_(hash_seed) {}

View File

@ -24,11 +24,7 @@ namespace lazy {
// Doing the former would limit IR introspection.
class TORCH_API Generic : public TorchMlirNode {
public:
Generic(
OpKind op,
OpList operands,
Shape shape,
size_t num_outputs = 1,
Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs = 1,
hash_t hash_seed = static_cast<uint32_t>(0x5a2d296e9));
private:

View File

@ -17,25 +17,28 @@
namespace torch {
namespace lazy {
// This IR was copied from code-generated output, but the entire _to_copy operator
// cannot be trivially code genereated since it is only desirable to capture IR for
// certain permutaions of _to_copy (e.g. dtype), and for the others it is difficult to even invoke
// the aten/eager fallback necessitating directly implementing the right to(device) behavior
// This IR was copied from code-generated output, but the entire _to_copy
// operator cannot be trivially code genereated since it is only desirable to
// capture IR for certain permutaions of _to_copy (e.g. dtype), and for the
// others it is difficult to even invoke the aten/eager fallback necessitating
// directly implementing the right to(device) behavior
class ToCopy : public torch::lazy::TorchMlirNode {
public:
ToCopy(const torch::lazy::Value& self, const c10::optional<at::ScalarType>& dtype, const c10::optional<at::Layout>& layout, const c10::optional<at::Device>& device, const c10::optional<bool>& pin_memory, const bool& non_blocking, const c10::optional<at::MemoryFormat>& memory_format, std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TorchMlirNode(torch::lazy::OpKind(at::aten::_to_copy),
{self}, std::move(shapes),
ToCopy(const torch::lazy::Value &self,
const c10::optional<at::ScalarType> &dtype,
const c10::optional<at::Layout> &layout,
const c10::optional<at::Device> &device,
const c10::optional<bool> &pin_memory, const bool &non_blocking,
const c10::optional<at::MemoryFormat> &memory_format,
std::vector<torch::lazy::Shape> &&shapes)
: torch::lazy::TorchMlirNode(
torch::lazy::OpKind(at::aten::_to_copy), {self}, std::move(shapes),
/* num_outputs */ 1,
torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, memory_format)),
torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking,
memory_format)),
dtype(dtype),
layout(layout),
device(device),
pin_memory(pin_memory),
non_blocking(non_blocking),
memory_format(memory_format) {}
dtype(dtype), layout(layout), device(device), pin_memory(pin_memory),
non_blocking(non_blocking), memory_format(memory_format) {}
std::string ToString() const override {
std::stringstream ss;
@ -69,7 +72,8 @@ class ToCopy : public torch::lazy::TorchMlirNode {
return ss.str();
}
torch::lazy::TorchMlirOpVector Lower(TorchMlirFunction function,
torch::lazy::TorchMlirOpVector
Lower(TorchMlirFunction function,
torch::lazy::TorchMlirLoweringContext *loctx) const override {
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
@ -83,11 +87,12 @@ class ToCopy : public torch::lazy::TorchMlirNode {
kwarguments.emplace_back("pin_memory", pin_memory);
kwarguments.emplace_back("non_blocking", non_blocking);
kwarguments.emplace_back("memory_format", memory_format);
torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
torch::lazy::TorchMlirOpVector _to_copy_out =
torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(),
arguments, kwarguments);
TORCH_CHECK_EQ(_to_copy_out.size(), 1);
return _to_copy_out;
}
c10::optional<at::ScalarType> dtype;

View File

@ -27,7 +27,6 @@ std::vector<torch::lazy::Shape> compute_shape_add(const at::Tensor& self,
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_sub(const at::Tensor &self,
const at::Scalar &other,
const at::Scalar &alpha) {
@ -96,9 +95,8 @@ std::vector<torch::lazy::Shape> compute_shape_quantize_per_channel(
}
std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
const at::Tensor& self, at::IntArrayRef kernel_size,
at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation,
bool ceil_mode) {
const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride,
at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) {
auto in_sizes = self.sizes().vec();
std::vector<int64_t> dhw(3, 0);
std::vector<int64_t> paddings = padding.vec();
@ -107,17 +105,18 @@ std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices(
std::vector<int64_t> strides = stride.vec();
TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ",
in_sizes);
TORCH_CHECK(kernel_size.size() == 3 &&
stride.size() == 3 &&
padding.size() == 3 &&
dilation.size() == 3, "max_pool3d requires 3D operands, but got ",
kernel_size, stride, padding, dilation);
TORCH_CHECK(kernel_size.size() == 3 && stride.size() == 3 &&
padding.size() == 3 && dilation.size() == 3,
"max_pool3d requires 3D operands, but got ", kernel_size, stride,
padding, dilation);
int64_t batch = in_sizes[0];
int64_t channel = in_sizes[1]; // NCDHW
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html
for (auto i = 0UL; i < 3; ++i) {
double out_size = (in_sizes[2+i] + 2 * paddings[i] - dilations[i] *
(ksizes[i] - 1) - 1) / (double)strides[i] + 1;
double out_size = (in_sizes[2 + i] + 2 * paddings[i] -
dilations[i] * (ksizes[i] - 1) - 1) /
(double)strides[i] +
1;
if (ceil_mode)
dhw[i] = (int64_t)std::ceil(out_size);
else
@ -136,8 +135,9 @@ std::vector<torch::lazy::Shape> compute_shape_max_pool3d_with_indices_backward(
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_mse_loss_backward(
const at::Tensor& grad_output, const at::Tensor& self,
std::vector<torch::lazy::Shape>
compute_shape_mse_loss_backward(const at::Tensor &grad_output,
const at::Tensor &self,
const at::Tensor &target, int64_t reduction) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -147,21 +147,22 @@ std::vector<torch::lazy::Shape> compute_shape_mul(const at::Tensor& self,
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_var(
const at::Tensor& self, at::OptionalIntArrayRef dim,
std::vector<torch::lazy::Shape>
compute_shape_var(const at::Tensor &self, at::OptionalIntArrayRef dim,
const c10::optional<at::Scalar> &correction, bool keepdim) {
// Result of variance is scalar tensor.
return {Shape(self.scalar_type(), {})};
}
std::vector<torch::lazy::Shape> compute_shape_nan_to_num(
const at::Tensor & self, c10::optional<double> nan,
c10::optional<double> posinf, c10::optional<double> neginf) {
std::vector<torch::lazy::Shape>
compute_shape_nan_to_num(const at::Tensor &self, c10::optional<double> nan,
c10::optional<double> posinf,
c10::optional<double> neginf) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_hardtanh(
const at::Tensor& self, const at::Scalar& min_val,
std::vector<torch::lazy::Shape>
compute_shape_hardtanh(const at::Tensor &self, const at::Scalar &min_val,
const at::Scalar &max_val) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -201,9 +202,9 @@ std::vector<torch::lazy::Shape> compute_shape_where(const at::Tensor& condition,
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_bucketize(
const at::Tensor& self, const at::Tensor& boundaries, bool out_int32,
bool right) {
std::vector<torch::lazy::Shape>
compute_shape_bucketize(const at::Tensor &self, const at::Tensor &boundaries,
bool out_int32, bool right) {
auto dtype = out_int32 ? at::kInt : at::kLong;
return {Shape(dtype, self.sizes().vec())};
}
@ -214,8 +215,8 @@ std::vector<torch::lazy::Shape> compute_shape_copy(const at::Tensor& self,
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_floor_divide(
const at::Tensor& self, const at::Tensor& other) {
std::vector<torch::lazy::Shape>
compute_shape_floor_divide(const at::Tensor &self, const at::Tensor &other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -244,9 +245,10 @@ std::vector<torch::lazy::Shape> compute_shape_native_group_norm(
return shapes;
}
std::vector<torch::lazy::Shape> compute_shape_im2col(
const at::Tensor& self, at::IntArrayRef kernel_size,
at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) {
std::vector<torch::lazy::Shape>
compute_shape_im2col(const at::Tensor &self, at::IntArrayRef kernel_size,
at::IntArrayRef dilation, at::IntArrayRef padding,
at::IntArrayRef stride) {
auto self_meta = at::native::empty_strided_meta_symint(
self.sym_sizes(), self.sym_strides(),
@ -280,8 +282,8 @@ std::vector<torch::lazy::Shape> compute_shape_native_group_norm_backward(
return shapes;
}
std::vector<torch::lazy::Shape> compute_shape_remainder(
const at::Tensor& self, const at::Scalar& other) {
std::vector<torch::lazy::Shape>
compute_shape_remainder(const at::Tensor &self, const at::Scalar &other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -313,20 +315,21 @@ compute_shape_reflection_pad2d(const at::Tensor &self,
return {Shape(self.scalar_type(), out_sizes)};
}
std::vector<torch::lazy::Shape> compute_shape_uniform(
const at::Tensor& self, double from, double to,
std::vector<torch::lazy::Shape>
compute_shape_uniform(const at::Tensor &self, double from, double to,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_normal_functional(
const at::Tensor& self, double mean, double std,
std::vector<torch::lazy::Shape>
compute_shape_normal_functional(const at::Tensor &self, double mean, double std,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_multinomial(
const at::Tensor& self, int64_t num_samples, bool replacement,
std::vector<torch::lazy::Shape>
compute_shape_multinomial(const at::Tensor &self, int64_t num_samples,
bool replacement,
c10::optional<at::Generator> generator) {
// Input tensor can be either 1D or 2D. The last dim of output
// should be 'num_samples'. So the output shape can be either
@ -337,27 +340,30 @@ std::vector<torch::lazy::Shape> compute_shape_multinomial(
return {Shape(at::kLong, ishape)};
}
std::vector<torch::lazy::Shape> compute_shape_eye(
int64_t n, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
std::vector<torch::lazy::Shape>
compute_shape_eye(int64_t n, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
auto out_meta =
at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_eye(
int64_t n, int64_t m, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
std::vector<torch::lazy::Shape>
compute_shape_eye(int64_t n, int64_t m, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
auto out_meta =
at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_arange(
const at::Scalar& end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
std::vector<torch::lazy::Shape>
compute_shape_arange(const at::Scalar &end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
auto out_meta =
at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory);
@ -390,25 +396,28 @@ std::vector<torch::lazy::Shape> compute_shape_full(
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_ones(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
std::vector<torch::lazy::Shape>
compute_shape_ones(at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_zeros(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
std::vector<torch::lazy::Shape>
compute_shape_zeros(at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_empty(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
std::vector<torch::lazy::Shape>
compute_shape_empty(at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory,
c10::optional<at::MemoryFormat> memory_format) {
return {
@ -433,9 +442,10 @@ std::vector<torch::lazy::Shape> compute_shape_fill(const at::Tensor& self,
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_randn(
at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device,
std::vector<torch::lazy::Shape>
compute_shape_randn(at::IntArrayRef size, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
return {
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
@ -457,14 +467,14 @@ std::vector<torch::lazy::Shape> compute_shape_randint(
Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_resize(
const at::Tensor & self, at::IntArrayRef size,
std::vector<torch::lazy::Shape>
compute_shape_resize(const at::Tensor &self, at::IntArrayRef size,
c10::optional<at::MemoryFormat> memory_format) {
return {Shape(self.scalar_type(), size.vec())};
}
std::vector<torch::lazy::Shape> compute_shape_bernoulli(
const at::Tensor& self, const at::Tensor &p,
std::vector<torch::lazy::Shape>
compute_shape_bernoulli(const at::Tensor &self, const at::Tensor &p,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
@ -476,17 +486,20 @@ std::vector<torch::lazy::Shape> compute_shape_scalar_tensor(
return {Shape(dtype.value_or(s.type()), c10::ArrayRef<int64_t>{})};
}
std::vector<torch::lazy::Shape> compute_shape_roll(
const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) {
std::vector<torch::lazy::Shape> compute_shape_roll(const at::Tensor &self,
at::IntArrayRef shifts,
at::IntArrayRef dims) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<torch::lazy::Shape> compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
auto out_meta =
at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory);
std::vector<torch::lazy::Shape> compute_shape_linspace(
const at::Scalar &start, const at::Scalar &end, int64_t steps,
c10::optional<at::ScalarType> dtype, c10::optional<at::Layout> layout,
c10::optional<at::Device> device, c10::optional<bool> pin_memory) {
auto out_meta = at::linspace(start, end, steps, dtype, layout,
c10::Device(c10::kMeta), pin_memory);
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
} // namespace lazy
} // namespace torch

View File

@ -14,8 +14,8 @@
namespace torch {
namespace lazy {
at::Tensor CreateFunctionalizedAtenFromLtcTensor(
const LazyTensorPtr& ltc_tensor) {
at::Tensor
CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr &ltc_tensor) {
at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor);
if (!c10::impl::tls_is_dispatch_key_excluded(
c10::DispatchKey::Functionalize) &&

View File

@ -18,7 +18,8 @@ namespace lazy {
// should have explicit tensor functinoalization. Otherwise we can get
// unfanctionalized primitives or in the worst case if we apply inplace
// operations to unfunctionalized tensor it won't be captured in LTC graph.
TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor);
TORCH_API at::Tensor
CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr &ltc_tensor);
} // namespace lazy
} // namespace torch

View File

@ -21,8 +21,8 @@
}
#define UNIMPLEMENTED_FUNCTION_ERROR() \
UNIMPLEMENTED_ERROR( \
"\n\t" << __FILE__ << ":" << __LINE__ << " " << __PRETTY_FUNCTION__)
UNIMPLEMENTED_ERROR("\n\t" << __FILE__ << ":" << __LINE__ << " " \
<< __PRETTY_FUNCTION__)
#define UNSUPPORTED_ERROR(msg) \
{ \

View File

@ -27,12 +27,10 @@ void ConvertScalarImplicit(std::shared_ptr<Graph>& graph) {
node_type = c10::aten::FloatImplicit;
output_type = FloatType::get();
} else {
throw std::runtime_error(
"Expected isIntegralType or isFloatingType");
throw std::runtime_error("Expected isIntegralType or isFloatingType");
}
Value * output = graph
->create(node_type, {input})
Value *output = graph->create(node_type, {input})
->insertBefore(node)
->output()
->setType(output_type);

View File

@ -1,15 +1,17 @@
#pragma once
#include <string>
#include <sstream>
#include <string>
#include <vector>
template <typename T>
std::ostream& string_join(std::ostream& out, const std::vector<T>& v, const std::string& delimiter) {
std::ostream &string_join(std::ostream &out, const std::vector<T> &v,
const std::string &delimiter) {
size_t i = 0;
for (const T &e : v) {
if ((i++) > 0) { out << delimiter; }
if ((i++) > 0) {
out << delimiter;
}
out << e;
}
return out;
@ -22,10 +24,8 @@ std::string string_join(const std::vector<T>& v, const std::string& delimiter) {
return joined.str();
}
inline std::vector<std::string> string_split(
const std::string& str,
const std::string& sep
) {
inline std::vector<std::string> string_split(const std::string &str,
const std::string &sep) {
std::vector<std::string> tokens;
std::size_t pos1 = str.find_first_not_of(sep);
while (pos1 != std::string::npos) {

View File

@ -14,7 +14,8 @@ static T GetEnv(const std::string& name, const T& default_value = T(0)) {
return T(std::atoi(env));
}
static std::string GetEnvString(const std::string& name, const std::string& default_value) {
static std::string GetEnvString(const std::string &name,
const std::string &default_value) {
const char *env = std::getenv(name.c_str());
if (!env) {
return default_value;

View File

@ -3,7 +3,6 @@
#include "../generated/LazyIr.h"
#include "../mlir_node.h"
namespace torch {
namespace lazy {
@ -15,9 +14,12 @@ bool is_detach_copy(const torch::lazy::Value& value) {
}
torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *node) {
if (!node) { return nullptr; }
if (!node) {
return nullptr;
}
torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<torch::lazy::TorchMlirNode*>(node);
torch::lazy::TorchMlirNode *mlir_node =
dynamic_cast<torch::lazy::TorchMlirNode *>(node);
while (mlir_node && is_detach_copy(mlir_node)) {
mlir_node = mlir_node->mlir_node(0);
}
@ -27,10 +29,14 @@ torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) {
return mlir_node;
}
const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) {
if (!node) { return nullptr; }
const torch::lazy::Node *
extract_non_detach_copy_node(const torch::lazy::Node *node) {
if (!node) {
return nullptr;
}
const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast<const torch::lazy::TorchMlirNode*>(node);
const torch::lazy::TorchMlirNode *mlir_node =
dynamic_cast<const torch::lazy::TorchMlirNode *>(node);
while (mlir_node && is_detach_copy(mlir_node)) {
mlir_node = mlir_node->mlir_node(0);
}
@ -40,7 +46,6 @@ const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* n
return mlir_node;
}
torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *node) {
if (!node) {
return nullptr;
@ -68,14 +73,15 @@ torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) {
return device_data_cast(value.node.get());
}
torch::lazy::DeviceData* device_data_cast(
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device
) {
torch::lazy::DeviceData *
device_data_cast(const at::Tensor &tensor,
c10::optional<torch::lazy::BackendDevice> device) {
if (!device) {
device = torch::lazy::GetBackendDevice(tensor);
}
TORCH_CHECK(device);
torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device);
torch::lazy::LazyTensorPtr lazy_tensor =
torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device);
if (lazy_tensor) {
return device_data_cast(lazy_tensor->GetIrValue());
}

View File

@ -12,14 +12,17 @@ TORCH_API bool is_detach_copy(const torch::lazy::Node*);
TORCH_API bool is_detach_copy(const torch::lazy::Value &);
TORCH_API torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *);
TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*);
TORCH_API const torch::lazy::Node *
extract_non_detach_copy_node(const torch::lazy::Node *);
TORCH_API torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *);
TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*);
TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value);
TORCH_API const torch::lazy::DeviceData *
device_data_cast(const torch::lazy::Node *);
TORCH_API torch::lazy::DeviceData *
device_data_cast(const torch::lazy::Value &value);
TORCH_API torch::lazy::DeviceData *device_data_cast(
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt
);
const at::Tensor &tensor,
c10::optional<torch::lazy::BackendDevice> device = c10::nullopt);
} // namespace lazy
} // namespace torch

View File

@ -73,10 +73,8 @@ public:
// Vendor backend specific lowering can be exec here before returning.
for (const auto& instance : instances) {
TORCH_CHECK(
instance->in_mark_step,
"Compile outside of mark step:\n",
GetComputationBackendText(instance)
);
instance->in_mark_step, "Compile outside of mark step:\n",
GetComputationBackendText(instance));
// Store computation instance for external access after compilation.
GetLatestComputation() = instance;
}
@ -114,12 +112,13 @@ public:
// Convert any lazy devices to cpu devices to ensure
// that the values are actually computed
if (node->outputs().size() == 1 &&
node->output()->type()->kind() ==
c10::TypeKind::DeviceObjType) {
node->output()->type()->kind() == c10::TypeKind::DeviceObjType) {
auto value_sym = torch::jit::Symbol::attr("value");
TORCH_CHECK(node->hasAttribute(value_sym),
TORCH_CHECK(
node->hasAttribute(value_sym),
"Expected node to have 'value' attribute.");
TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s,
TORCH_CHECK(
node->kindOf(value_sym) == torch::jit::AttributeKind::s,
"Expected 'value' attribute to be a string.");
if (beginswith(node->s(value_sym), "lazy")) {
node->s_(value_sym, "cpu");
@ -132,7 +131,8 @@ public:
for (const auto& argument : arguments) {
const auto mlir_data =
std::static_pointer_cast<TorchMlirBackendData>(argument);
auto* info = dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
auto* info =
dynamic_cast<TorchMlirBackendData::Info*>(mlir_data->mlir_info());
TORCH_CHECK(info);
if (info->scalar.has_value()) {
stack.emplace_back(info->scalar.value());

View File

@ -8,8 +8,8 @@
//===----------------------------------------------------------------------===//
#include "torch/csrc/jit/python/pybind.h"
#include "torch/csrc/lazy/core/config.h"
#include "torch/csrc/lazy/backend/backend_interface.h"
#include "torch/csrc/lazy/core/config.h"
#include <base_lazy_backend/mlir_lowering_context.h>
#include <base_lazy_backend/utils/string_utils.h>
@ -82,9 +82,11 @@ PYBIND11_MODULE(_REFERENCE_LAZY_BACKEND, m) {
torch::lazy::GetLatestComputation().get());
return py::cast(computation);
});
m.def("set_parameter_name",
m.def(
"set_parameter_name",
[](const at::Tensor& tensor, const std::string& name) -> bool {
torch::lazy::DeviceData* ir_node = torch::lazy::device_data_cast(tensor);
torch::lazy::DeviceData* ir_node =
torch::lazy::device_data_cast(tensor);
if (ir_node) {
ir_node->SetName(name);
return true;