diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp
index 452215b35..c075babf7 100644
--- a/lib/Conversion/TorchToLinalg/DataMovement.cpp
+++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -366,11 +366,7 @@ public:
     auto [inputShape, outputShape] =
         getInputAndOutputShape(op.getSelf(), outputSizeTorchInt,
                                resultType.getShape());
 
-    llvm::errs() << "inputShape: ";
-    llvm::interleaveComma(inputShape, llvm::errs());
-    llvm::errs() << " \noutput shape: ";
-    llvm::interleaveComma(outputShape, llvm::errs());
 
     // Currently, we only handle the cases where each dimension is either
     // being expanded or collapsed. We do not handle cases where it's neither
@@ -395,10 +391,6 @@ public:
                        m_TorchTensorSizeInt(op.getSelf(), &inputDim))) {
         unchangedDims.push_back(std::make_pair(inputDim, outputDim));
-        llvm::errs() << "inputDim: ";
-        llvm::errs() << inputDim;
-        llvm::errs() << " \noutput Dim: ";
-        llvm::errs() << outputDim;
       }
     }
     // Mark the end of the input/output shapes
 
@@ -440,10 +432,10 @@ public:
     // Used for ensuring that we don't have an ambiguous expansion
     bool assumedDynamicDimNotSplit = false;
     while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
-      auto inputShapeSlice = // 1, ?, 4096
+      auto inputShapeSlice = // ?, 4096
           MutableArrayRef<int64_t>(inputShape)
              .slice(inputDim, nextUnchangedInput - inputDim);
-      auto outputShapeSlice = // ?, 4096
+      auto outputShapeSlice = // 1, ?, 4096
           MutableArrayRef<int64_t>(outputShape)
              .slice(outputDim, nextUnchangedOutput - outputDim);
       SmallVector<int64_t> inputSliceIndices;
@@ -483,13 +475,21 @@ public:
       /// known to have the same number of elements.
       } else if (inputShapeSlice[0] == kUnknownSize) {
         // If the input is dynamic, assume it is not split
-        checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
-                            outputSizeInt[outputDim]);
+        // Elide any unit dims in the output, then match the dynamic dim.
+
+        int64_t idx = 0;
+        while (idx < static_cast<int64_t>(outputShapeSlice.size()) &&
+               outputShapeSlice[idx] == 1) {
+          outputSliceIndices.push_back(idx++);
+        }
+
+        // checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
+        //                     outputSizeInt[outputDim + idx]);
         // If output dimension is not dynamic, improve static information of
         // input
-        inputShape[inputDim] = outputShape[outputDim];
+        inputShape[inputDim] = outputShapeSlice[idx];
         inputSliceIndices.push_back(0);
-        outputSliceIndices.push_back(0);
+        outputSliceIndices.push_back(idx);
         assumedDynamicDimNotSplit = true;
       } else if (inputShapeSlice[0] == 1 && outputShapeSlice[0] == kUnknownSize) {
         int64_t idx = 0;
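
Aside on the DataMovement.cpp change: the new loop skips leading unit dims in the output slice, then pairs the dynamic input dim with the first non-unit output dim. A standalone C++ sketch of that matching step (my own mock, not torch-mlir code; kUnknownSize and std::vector stand in for the real dynamic-size sentinel and MutableArrayRef, and the assert makes explicit the all-unit-dims edge case the patch leaves unguarded):

// elide_unit_dims_sketch.cpp -- illustration only, not torch-mlir code.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

constexpr int64_t kUnknownSize = -1; // stand-in for the dynamic-dim sentinel

int main() {
  // Shapes from the patch comments: input slice {?, 4096},
  // output slice {1, ?, 4096}.
  std::vector<int64_t> outputShapeSlice = {1, kUnknownSize, 4096};
  std::vector<int64_t> outputSliceIndices;

  // Record (and skip) leading unit dims so the dynamic input dim is not
  // treated as split across them.
  int64_t idx = 0;
  while (idx < static_cast<int64_t>(outputShapeSlice.size()) &&
         outputShapeSlice[idx] == 1)
    outputSliceIndices.push_back(idx++);
  assert(idx < static_cast<int64_t>(outputShapeSlice.size()) &&
         "output slice was all unit dims");

  // The dynamic input dim is assumed to map to slice position idx.
  outputSliceIndices.push_back(idx);
  std::cout << "dynamic input dim maps to output slice index " << idx << "\n";
  for (int64_t i : outputSliceIndices)
    std::cout << "recorded output slice index: " << i << "\n";
}

Compiled and run, this prints index 1 for the {1, ?, 4096} example: the leading unit dim is recorded separately and the `?` in the input lines up with the `?` in the output instead of being split.
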
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 6136db092..58776d0bb 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1543,6 +1543,8 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
                                 PatternRewriter &rewriter) const override {
+
+    llvm::errs() << "dbg decompose fill_scalar\n";
     Location loc = op.getLoc();
     auto resType = op.getType().cast<BaseTensorType>();
     if (!resType.hasDtype()) {
@@ -4429,13 +4431,18 @@ public:
   LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
+    llvm::errs() << "dbg Try decompose empty strided\n";
+    // if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts))) {
+    //   llvm::errs() << "dbg all size list ele const ints\n";
+    //   return rewriter.notifyMatchFailure(
+    //       op, "all size list elements must be constant ints");
+    // }
+    // if (!matchPattern(op.getStride(),
+    //                   m_TorchListOfConstantInts(strideListInts))) {
+    //   llvm::errs() << "dbg all stride list ele const ints\n";
+    //   return rewriter.notifyMatchFailure(
+    //       op, "all stride list elements must be constant ints");
+    // }
 
     // We only support the cases with default stride values.
     // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
@@ -4451,13 +4458,15 @@ public:
         break;
       }
     }
-    if (!isDefaultStride)
+    if (!isDefaultStride) {
+      llvm::errs() << "dbg non default strides\n";
       return rewriter.notifyMatchFailure(
           op, "only default strides supported for new_empty_strided op");
-
+    }
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
         op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
         op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
+    llvm::errs() << "dbg success\n";
 
     return success();
diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
index 3b30b8a82..e7db5f678 100644
--- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
+++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp
@@ -227,12 +227,14 @@ public:
     if (!op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>())
       return rewriter.notifyMatchFailure(op, "is not trailing_ variant");
 
+    llvm::errs() << "dbg op name: " << op->getName().getStringRef() << "\n";
     SmallVector<StringRef> fragments;
     llvm::SplitString(op->getName().getStringRef(), fragments, ".");
     assert(fragments.size() >= 3 && fragments[2].endswith("_") &&
            "IsTrailingUnderscoreInplaceVariant incorrectly applied");
     fragments[2] = fragments[2].drop_back();
     std::string noUnderscoreName = llvm::join(fragments, ".");
+    llvm::errs() << "dbg fragments good\n";
 
     OperationState state(op->getLoc(), noUnderscoreName);
     state.addTypes(op->getResultTypes());
@@ -241,14 +243,19 @@ public:
     // Note: No successors or regions. Torch JIT operators don't have any.
     assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 &&
            "Torch JIT operators shouldn't have regions or successors");
+    llvm::errs() << "dbg good regions, good successors\n";
 
     Operation *newOp = rewriter.create(state);
     auto tensor =
         rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
+
+    llvm::errs() << "dbg create copy_to_value_tensor\n";
     createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
                                   op->getOperand(0));
+
     rewriter.replaceOp(op, op->getOperand(0));
+    llvm::errs() << "dbg success\n";
 
     return success();
   }
 };
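
Aside on the DecomposeComplexOps.cpp change: with the constant-int checks commented out, the only remaining guard is the default-stride test. A minimal standalone sketch of that rule (my own helper, not the pattern's actual loop): default means contiguous row-major strides, where stride[i] is the product of all later sizes.

// default_stride_sketch.cpp -- illustration only, not torch-mlir code.
#include <cstdint>
#include <iostream>
#include <vector>

static bool isDefaultStride(const std::vector<int64_t> &sizes,
                            const std::vector<int64_t> &strides) {
  if (sizes.size() != strides.size())
    return false;
  int64_t expected = 1; // innermost stride of a contiguous tensor
  for (size_t i = sizes.size(); i-- > 0;) {
    if (strides[i] != expected)
      return false;
    expected *= sizes[i];
  }
  return true;
}

int main() {
  // The example from the patch comment: size=[2, 3, 4] <-> stride=[12, 4, 1].
  std::cout << isDefaultStride({2, 3, 4}, {12, 4, 1}) << "\n"; // 1 (default)
  std::cout << isDefaultStride({2, 3, 4}, {1, 3, 6}) << "\n";  // 0 (not)
}

Note that with sizeListInts and strideListInts no longer populated by matchPattern, the pattern's own isDefaultStride loop iterates over empty lists and trivially passes, so commenting out those checks changes which ops get rewritten, not just the logging.
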