mirror of https://github.com/llvm/torch-mlir
view hack
parent
bfcf93ea21
commit
484ee82c28
|
@ -741,15 +741,23 @@ public:
|
||||||
// `aten.view`.
|
// `aten.view`.
|
||||||
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
|
static std::pair<SmallVector<int64_t>, SmallVector<int64_t>>
|
||||||
getInputAndOutputShape(Value inputTorchTensor,
|
getInputAndOutputShape(Value inputTorchTensor,
|
||||||
SmallVector<Value> outputSizeTorchInt) {
|
SmallVector<Value> outputSizeTorchInt, ArrayRef<int64_t> resultTypeShape) {
|
||||||
SmallVector<int64_t> inputShape(
|
SmallVector<int64_t> inputShape(
|
||||||
inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
|
inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
|
||||||
SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
|
SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
|
||||||
for (auto [outputDim, outputDimSize] :
|
// for (auto [outputDim, outputDimSize] :
|
||||||
llvm::enumerate(outputSizeTorchInt)) {
|
// llvm::enumerate(outputSizeTorchInt)) {
|
||||||
|
for (const auto &it : llvm::enumerate(llvm::zip_equal(outputSizeTorchInt, resultTypeShape))) {
|
||||||
|
int64_t outputDim = it.index();
|
||||||
|
Value outputDimSize = std::get<0>(it.value());
|
||||||
|
int64_t resultShapeSize = std::get<1>(it.value());
|
||||||
int64_t inputDim;
|
int64_t inputDim;
|
||||||
int64_t outputDimSizeInt;
|
int64_t outputDimSizeInt;
|
||||||
// Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
|
// Match torch.aten.size.int(inputTensor, inputDim) with constant inputDim
|
||||||
|
if (resultShapeSize >= 0) {
|
||||||
|
outputShape[outputDim] = resultShapeSize;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (matchPattern(outputDimSize,
|
if (matchPattern(outputDimSize,
|
||||||
m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) {
|
m_TorchTensorSizeInt(inputTorchTensor, &inputDim))) {
|
||||||
outputShape[outputDim] = inputShape[inputDim];
|
outputShape[outputDim] = inputShape[inputDim];
|
||||||
|
@ -813,7 +821,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [inputShape, outputShape] =
|
auto [inputShape, outputShape] =
|
||||||
getInputAndOutputShape(op.getSelf(), outputSizeTorchInt);
|
getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
|
||||||
|
|
||||||
// Currently, we only handle the cases where each dimension is either
|
// Currently, we only handle the cases where each dimension is either
|
||||||
// being expanded or collapsed. We do not handle cases where it's neither
|
// being expanded or collapsed. We do not handle cases where it's neither
|
||||||
|
|
Loading…
Reference in New Issue