mirror of https://github.com/llvm/torch-mlir
debugging and hacks
parent
f70978c034
commit
1e3df9c943
|
@ -366,11 +366,7 @@ public:
|
|||
|
||||
auto [inputShape, outputShape] =
|
||||
getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
|
||||
llvm::errs() << "inputShape: ";
|
||||
llvm::interleaveComma(inputShape, llvm::errs());
|
||||
|
||||
llvm::errs() << " \noutput shape: ";
|
||||
llvm::interleaveComma(outputShape, llvm::errs());
|
||||
|
||||
// Currently, we only handle the cases where each dimension is either
|
||||
// being expanded or collapsed. We do not handle cases where it's neither
|
||||
|
@ -395,10 +391,6 @@ public:
|
|||
m_TorchTensorSizeInt(op.getSelf(), &inputDim))) {
|
||||
unchangedDims.push_back(std::make_pair(inputDim, outputDim));
|
||||
|
||||
llvm::errs() << "inputDim: ";
|
||||
llvm::errs() << inputDim;
|
||||
llvm::errs() << " \noutput Dim: ";
|
||||
llvm::errs() << outputDim;
|
||||
}
|
||||
}
|
||||
// Mark the end of the input/output shapes
|
||||
|
@ -440,10 +432,10 @@ public:
|
|||
// Used for ensuring that we don't have an ambiguous expansion
|
||||
bool assumedDynamicDimNotSplit = false;
|
||||
while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
|
||||
auto inputShapeSlice = // 1, ?, 4096
|
||||
auto inputShapeSlice = // ?, 4096
|
||||
MutableArrayRef<int64_t>(inputShape)
|
||||
.slice(inputDim, nextUnchangedInput - inputDim);
|
||||
auto outputShapeSlice = // ?, 4096
|
||||
auto outputShapeSlice = //1, ?, 4096
|
||||
MutableArrayRef<int64_t>(outputShape)
|
||||
.slice(outputDim, nextUnchangedOutput - outputDim);
|
||||
SmallVector<int64_t> inputSliceIndices;
|
||||
|
@ -483,13 +475,21 @@ public:
|
|||
/// known to have the same number of elements.
|
||||
} else if (inputShapeSlice[0] == kUnknownSize) {
|
||||
// If the input is dynamic, assume it is not split
|
||||
checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
|
||||
outputSizeInt[outputDim]);
|
||||
// Elide any unit dims in the output and checkDimEqual on dynamic dims
|
||||
|
||||
int64_t idx = 0;
|
||||
while (idx < outputShapeSlice.size() && outputShapeSlice[idx] == 1) {
|
||||
outputSliceIndices.push_back(idx++);
|
||||
}
|
||||
|
||||
|
||||
//checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
|
||||
// outputSizeInt[idx]);
|
||||
// If output dimension is not dynamic, improve static information of
|
||||
// input
|
||||
inputShape[inputDim] = outputShape[outputDim];
|
||||
inputShape[inputDim] = outputShape[idx];
|
||||
inputSliceIndices.push_back(0);
|
||||
outputSliceIndices.push_back(0);
|
||||
outputSliceIndices.push_back(idx);
|
||||
assumedDynamicDimNotSplit = true;
|
||||
} else if (inputShapeSlice[0] == 1 && outputShapeSlice[0] == kUnknownSize) {
|
||||
int64_t idx = 0;
|
||||
|
|
|
@ -1543,6 +1543,8 @@ public:
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
llvm::errs() << "dbg decompose fill_scalar\n";
|
||||
Location loc = op.getLoc();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
if (!resType.hasDtype()) {
|
||||
|
@ -4429,13 +4431,18 @@ public:
|
|||
LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<int64_t> sizeListInts, strideListInts;
|
||||
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all size list elements must be constant ints");
|
||||
if (!matchPattern(op.getStride(),
|
||||
m_TorchListOfConstantInts(strideListInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "all stride list elements must be constant ints");
|
||||
llvm::errs() << "dbg Try decompose empty strided\n";
|
||||
//if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts))){
|
||||
|
||||
// llvm::errs() << "dbg all size list ele const ints\n";
|
||||
// return rewriter.notifyMatchFailure(
|
||||
// op, "all size list elements must be constant ints");
|
||||
//}
|
||||
//if (!matchPattern(op.getStride(),
|
||||
// m_TorchListOfConstantInts(strideListInts))){
|
||||
// llvm::errs() << "dbg all stride list ele const ints\n";
|
||||
// return rewriter.notifyMatchFailure(
|
||||
// op, "all stride list elements must be constant ints");}
|
||||
|
||||
// We only support the cases with default stride values.
|
||||
// For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
|
||||
|
@ -4451,16 +4458,18 @@ public:
|
|||
break;
|
||||
}
|
||||
}
|
||||
if (!isDefaultStride)
|
||||
if (!isDefaultStride){
|
||||
llvm::errs() << "dbg non default strides\n";
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only default strides supported for new_empty_strided op");
|
||||
|
||||
}
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
|
||||
op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(),
|
||||
op.getPinMemory(), /*memoryFormat=*/noneVal);
|
||||
|
||||
llvm::errs() << "dbg success\n";
|
||||
return success();
|
||||
|
||||
|
||||
|
|
|
@ -227,12 +227,14 @@ public:
|
|||
if (!op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>())
|
||||
return rewriter.notifyMatchFailure(op, "is not trailing_ variant");
|
||||
|
||||
llvm::errs() << "dbg op name: " << op->getName().getStringRef() << "\n";
|
||||
SmallVector<StringRef> fragments;
|
||||
llvm::SplitString(op->getName().getStringRef(), fragments, ".");
|
||||
assert(fragments.size() >= 3 && fragments[2].endswith("_") &&
|
||||
"IsTrailingUnderscoreInplaceVariant incorrectly applied");
|
||||
fragments[2] = fragments[2].drop_back();
|
||||
std::string noUnderscoreName = llvm::join(fragments, ".");
|
||||
llvm::errs() << "dbg fragments good\n";
|
||||
|
||||
OperationState state(op->getLoc(), noUnderscoreName);
|
||||
state.addTypes(op->getResultTypes());
|
||||
|
@ -241,14 +243,19 @@ public:
|
|||
// Note: No successors or regions. Torch JIT operators don't have any.
|
||||
assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 &&
|
||||
"Torch JIT operators shouldn't have regions or successors");
|
||||
llvm::errs() << "dbg good regions, good successors\n";
|
||||
|
||||
Operation *newOp = rewriter.create(state);
|
||||
auto tensor =
|
||||
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
|
||||
|
||||
llvm::errs() << "dbg create copy_to_value_tensor\n";
|
||||
createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
|
||||
op->getOperand(0));
|
||||
|
||||
rewriter.replaceOp(op, op->getOperand(0));
|
||||
|
||||
llvm::errs() << "dbg success\n";
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue