view hack

int_view_hack
dan 2024-02-07 03:49:41 +00:00
parent bfcf93ea21
commit 484ee82c28
1 changed files with 12 additions and 4 deletions

View File

@ -741,15 +741,23 @@ public:
// `aten.view`. // `aten.view`.
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>> static std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
getInputAndOutputShape(Value inputTorchTensor, getInputAndOutputShape(Value inputTorchTensor,
SmallVector<Value> outputSizeTorchInt) { SmallVector<Value> outputSizeTorchInt, ArrayRef<int64_t> resultTypeShape) {
SmallVector<int64_t> inputShape( SmallVector<int64_t> inputShape(
inputTorchTensor.getType().cast<BaseTensorType>().getSizes()); inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize); SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
for (auto [outputDim, outputDimSize] : // for (auto [outputDim, outputDimSize] :
llvm::enumerate(outputSizeTorchInt)) { // llvm::enumerate(outputSizeTorchInt)) {
for (const auto &it : llvm::enumerate(llvm::zip_equal(outputSizeTorchInt, resultTypeShape))) {
int64_t outputDim = it.index();
Value outputDimSize = std::get<0>(it.value());
int64_t resultShapeSize = std::get<1>(it.value());
int64_t inputDim; int64_t inputDim;
int64_t outputDimSizeInt; int64_t outputDimSizeInt;
// Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim // Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
if (resultShapeSize >= 0) {
outputShape[outputDim] = resultShapeSize;
continue;
}
if (matchPattern(outputDimSize, if (matchPattern(outputDimSize,
m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) { m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) {
outputShape[outputDim] = inputShape[inputDim]; outputShape[outputDim] = inputShape[inputDim];
@ -813,7 +821,7 @@ public:
} }
auto [inputShape, outputShape] = auto [inputShape, outputShape] =
getInputAndOutputShape(op.getSelf(), outputSizeTorchInt); getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
// Currently, we only handle the cases where each dimension is either // Currently, we only handle the cases where each dimension is either
// being expanded or collapsed. We do not handle cases where it's neither // being expanded or collapsed. We do not handle cases where it's neither