mirror of https://github.com/llvm/torch-mlir

commit 1e3df9c943: debugging and hacks
parent f70978c034
@@ -366,11 +366,7 @@ public:
 
     auto [inputShape, outputShape] =
         getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
-    llvm::errs() << "inputShape: ";
-    llvm::interleaveComma(inputShape, llvm::errs());
 
-    llvm::errs() << " \noutput shape: ";
-    llvm::interleaveComma(outputShape, llvm::errs());
 
     // Currently, we only handle the cases where each dimension is either
     // being expanded or collapsed. We do not handle cases where it's neither
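The deleted lines dumped the computed shapes with llvm::interleaveComma. For reference, a minimal standalone sketch of that printing idiom; the helper name printShape is assumed, not part of the patch (dynamic dims are kUnknownSize, which prints as -1):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/STLExtras.h"
    #include "llvm/ADT/StringRef.h"
    #include "llvm/Support/raw_ostream.h"

    // Hypothetical helper: dump a shape to stderr as e.g. "inputShape: 1, -1, 4096".
    static void printShape(llvm::StringRef label, llvm::ArrayRef<int64_t> shape) {
      llvm::errs() << label << ": ";
      llvm::interleaveComma(shape, llvm::errs()); // comma-separated, no trailing separator
      llvm::errs() << "\n";
    }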
@@ -395,10 +391,6 @@ public:
                        m_TorchTensorSizeInt(op.getSelf(), &inputDim))) {
       unchangedDims.push_back(std::make_pair(inputDim, outputDim));
 
-      llvm::errs() << "inputDim: ";
-      llvm::errs() << inputDim;
-      llvm::errs() << " \noutput Dim: ";
-      llvm::errs() << outputDim;
     }
   }
   // Mark the end of the input/output shapes
@@ -440,10 +432,10 @@ public:
     // Used for ensuring that we don't have an ambiguous expansion
     bool assumedDynamicDimNotSplit = false;
     while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
-      auto inputShapeSlice = // 1, ?, 4096
+      auto inputShapeSlice = // ?, 4096
          MutableArrayRef<int64_t>(inputShape)
              .slice(inputDim, nextUnchangedInput - inputDim);
-      auto outputShapeSlice = // ?, 4096
+      auto outputShapeSlice = //1, ?, 4096
          MutableArrayRef<int64_t>(outputShape)
              .slice(outputDim, nextUnchangedOutput - outputDim);
      SmallVector<int64_t> inputSliceIndices;
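MutableArrayRef::slice(N, M) yields the M elements starting at index N, so each iteration inspects only the window between two unchanged dims. A small sketch with illustrative values (the shape and bounds are examples, not from the patch):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"

    void sliceExample() {
      // '?' in the comments above is a dynamic dim, kUnknownSize (-1) in torch-mlir.
      llvm::SmallVector<int64_t> outputShape = {1, -1, 4096};
      int64_t outputDim = 0, nextUnchangedOutput = 3;
      // Window of output dims between the current position and the next unchanged dim.
      auto outputShapeSlice = llvm::MutableArrayRef<int64_t>(outputShape)
                                  .slice(outputDim, nextUnchangedOutput - outputDim);
      (void)outputShapeSlice; // here: the whole shape, {1, -1, 4096}
    }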
@@ -483,13 +475,21 @@ public:
       /// known to have the same number of elements.
       } else if (inputShapeSlice[0] == kUnknownSize) {
         // If the input is dynamic, assume it is not split
-        checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
-                            outputSizeInt[outputDim]);
+        // Elide any unit dims in the output and checkDimEqual on dynamic dims
+
+        int64_t idx = 0;
+        while (idx < outputShapeSlice.size() && outputShapeSlice[idx] == 1) {
+          outputSliceIndices.push_back(idx++);
+        }
+
+
+        //checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
+        //                    outputSizeInt[idx]);
         // If output dimension is not dynamic, improve static information of
         // input
-        inputShape[inputDim] = outputShape[outputDim];
+        inputShape[inputDim] = outputShape[idx];
         inputSliceIndices.push_back(0);
-        outputSliceIndices.push_back(0);
+        outputSliceIndices.push_back(idx);
         assumedDynamicDimNotSplit = true;
       } else if (inputShapeSlice[0] == 1 && outputShapeSlice[0] == kUnknownSize) {
         int64_t idx = 0;
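The new loop skips any leading unit dims in the output window before pairing the dynamic input dim; since the checkDimEqualHelper call is left commented out, equality of the dynamic dims is now assumed rather than verified at runtime. A standalone sketch of the elision step, with an assumed helper name:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"

    // Record the indices of leading 1-sized dims; return the index of the first
    // non-unit dim, which is taken to correspond to the dynamic input dim.
    // For outputShapeSlice = {1, -1, 4096} this records {0} and returns 1.
    static int64_t elideLeadingUnitDims(llvm::ArrayRef<int64_t> outputShapeSlice,
                                        llvm::SmallVectorImpl<int64_t> &indices) {
      int64_t idx = 0;
      while (idx < (int64_t)outputShapeSlice.size() && outputShapeSlice[idx] == 1)
        indices.push_back(idx++);
      return idx;
    }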
@@ -1543,6 +1543,8 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
                                 PatternRewriter &rewriter) const override {
+
+    llvm::errs() << "dbg decompose fill_scalar\n";
     Location loc = op.getLoc();
     auto resType = op.getType().cast<BaseTensorType>();
     if (!resType.hasDtype()) {
@@ -4429,13 +4431,18 @@ public:
   LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
+    llvm::errs() << "dbg Try decompose empty strided\n";
+    //if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts))){
+
+    //  llvm::errs() << "dbg all size list ele const ints\n";
+    //  return rewriter.notifyMatchFailure(
+    //      op, "all size list elements must be constant ints");
+    //}
+    //if (!matchPattern(op.getStride(),
+    //                  m_TorchListOfConstantInts(strideListInts))){
+    //  llvm::errs() << "dbg all stride list ele const ints\n";
+    //  return rewriter.notifyMatchFailure(
+    //      op, "all stride list elements must be constant ints");}
 
     // We only support the cases with default stride values.
     // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
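"Default stride values" here means the contiguous row-major strides implied by the size list, as in the example comment: size=[2, 3, 4] gives stride=[12, 4, 1]. A minimal sketch of that computation (the helper name is assumed, not part of the pattern):

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/ADT/SmallVector.h"

    // stride[i] is the product of all sizes after dim i; the last stride is 1.
    static llvm::SmallVector<int64_t> defaultStrides(llvm::ArrayRef<int64_t> sizes) {
      llvm::SmallVector<int64_t> strides(sizes.size(), 1);
      for (int64_t i = (int64_t)sizes.size() - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * sizes[i + 1];
      return strides; // defaultStrides({2, 3, 4}) == {12, 4, 1}
    }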
@@ -4451,16 +4458,18 @@ public:
         break;
       }
     }
-    if (!isDefaultStride)
+    if (!isDefaultStride){
+      llvm::errs() << "dbg non default strides\n";
       return rewriter.notifyMatchFailure(
           op, "only default strides supported for new_empty_strided op");
+    }
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
 
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
         op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(),
         op.getPinMemory(), /*memoryFormat=*/noneVal);
 
+    llvm::errs() << "dbg success\n";
     return success();
 
 
@@ -227,12 +227,14 @@ public:
     if (!op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>())
       return rewriter.notifyMatchFailure(op, "is not trailing_ variant");
 
+    llvm::errs() << "dbg op name: " << op->getName().getStringRef() << "\n";
     SmallVector<StringRef> fragments;
     llvm::SplitString(op->getName().getStringRef(), fragments, ".");
     assert(fragments.size() >= 3 && fragments[2].endswith("_") &&
            "IsTrailingUnderscoreInplaceVariant incorrectly applied");
     fragments[2] = fragments[2].drop_back();
     std::string noUnderscoreName = llvm::join(fragments, ".");
+    llvm::errs() << "dbg fragments good\n";
 
     OperationState state(op->getLoc(), noUnderscoreName);
     state.addTypes(op->getResultTypes());
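The name rewrite splits the op name on dots and drops the trailing underscore from the third fragment, e.g. torch.aten.add_ becomes torch.aten.add. A condensed, self-contained sketch of the same logic (the function wrapper is assumed):

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/ADT/StringExtras.h"
    #include "llvm/ADT/StringRef.h"
    #include <cassert>
    #include <string>

    // "torch.aten.add_" -> "torch.aten.add"
    static std::string dropTrailingUnderscore(llvm::StringRef opName) {
      llvm::SmallVector<llvm::StringRef> fragments;
      llvm::SplitString(opName, fragments, ".");
      assert(fragments.size() >= 3 && fragments[2].endswith("_"));
      fragments[2] = fragments[2].drop_back(); // strip the "_"
      return llvm::join(fragments, ".");
    }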
@@ -241,14 +243,19 @@ public:
     // Note: No successors or regions. Torch JIT operators don't have any.
     assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 &&
            "Torch JIT operators shouldn't have regions or successors");
+    llvm::errs() << "dbg good regions, good successors\n";
 
     Operation *newOp = rewriter.create(state);
     auto tensor =
         rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
+
+    llvm::errs() << "dbg create copy_to_value_tensor\n";
     createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
                                   op->getOperand(0));
 
     rewriter.replaceOp(op, op->getOperand(0));
+
+    llvm::errs() << "dbg success\n";
     return success();
   }
 };
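Overall, this pattern replaces the trailing-underscore (in-place) op with its value-semantics variant, copies the result to a value tensor, and overwrites the original operand's contents. A rough before/after IR sketch; the op names follow the diff, but types and exact assembly syntax are elided and approximate:

    // Before: in-place variant mutating its first operand %t.
    //   %r = torch.aten.abs_ %t
    // After the rewrite:
    //   %0 = torch.aten.abs %t                 // newOp, underscore dropped
    //   %1 = torch.copy.to_vtensor %0          // CopyToValueTensorOp
    //   (overwrite the contents of %t with %1) // createOverwriteTensorContents
    //   uses of %r are replaced with %t        // rewriter.replaceOp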