debugging and hacks

llama2_wip
dan 2023-09-27 17:33:34 +00:00
parent f70978c034
commit 1e3df9c943
3 changed files with 39 additions and 23 deletions

View File

@@ -366,11 +366,7 @@ public:
     auto [inputShape, outputShape] =
         getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
-    llvm::errs() << "inputShape: ";
-    llvm::interleaveComma(inputShape, llvm::errs());
-    llvm::errs() << " \noutput shape: ";
-    llvm::interleaveComma(outputShape, llvm::errs());
     // Currently, we only handle the cases where each dimension is either
     // being expanded or collapsed. We do not handle cases where it's neither
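     // e.g. a view from [2,3,4] to [6,4] collapses input dims 0 and 1 into
     // output dim 0, while a view from [6,4] to [2,3,4] expands input dim 0
     // into output dims 0 and 1.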
@@ -395,10 +391,6 @@ public:
                    m_TorchTensorSizeInt(op.getSelf(), &inputDim))) {
       unchangedDims.push_back(std::make_pair(inputDim, outputDim));
-      llvm::errs() << "inputDim: ";
-      llvm::errs() << inputDim;
-      llvm::errs() << " \noutput Dim: ";
-      llvm::errs() << outputDim;
     }
   }
   // Mark the end of the input/output shapes
@@ -440,10 +432,10 @@ public:
     // Used for ensuring that we don't have an ambiguous expansion
     bool assumedDynamicDimNotSplit = false;
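     // (For instance, viewing a dynamic extent [?] as [2, ?] would require
     // splitting an unknown size, which cannot be verified statically; the
     // lowering instead assumes a dynamic input dim maps to one output dim.)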
     while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
-      auto inputShapeSlice = // 1, ?, 4096
+      auto inputShapeSlice = // ?, 4096
           MutableArrayRef<int64_t>(inputShape)
               .slice(inputDim, nextUnchangedInput - inputDim);
-      auto outputShapeSlice = // ?, 4096
+      auto outputShapeSlice = // 1, ?, 4096
           MutableArrayRef<int64_t>(outputShape)
               .slice(outputDim, nextUnchangedOutput - outputDim);
       SmallVector<int64_t> inputSliceIndices;
@@ -483,13 +475,21 @@ public:
       /// known to have the same number of elements.
       } else if (inputShapeSlice[0] == kUnknownSize) {
         // If the input is dynamic, assume it is not split
-        checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
-                            outputSizeInt[outputDim]);
+        // Elide any unit dims in the output and checkDimEqual on dynamic dims
+        int64_t idx = 0;
+        while (idx < outputShapeSlice.size() && outputShapeSlice[idx] == 1) {
+          outputSliceIndices.push_back(idx++);
+        }
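+        // e.g. with inputShapeSlice = [?, 4096] and outputShapeSlice =
+        // [1, ?, 4096], the loop above records the leading unit output dim
+        // and leaves idx at the dynamic output dim that pairs with the
+        // dynamic input dim.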
+        //checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
+        //                    outputSizeInt[idx]);
         // If output dimension is not dynamic, improve static information of
         // input
-        inputShape[inputDim] = outputShape[outputDim];
+        inputShape[inputDim] = outputShape[idx];
         inputSliceIndices.push_back(0);
-        outputSliceIndices.push_back(0);
+        outputSliceIndices.push_back(idx);
         assumedDynamicDimNotSplit = true;
+      } else if (inputShapeSlice[0] == 1 && outputShapeSlice[0] == kUnknownSize) {
+        int64_t idx = 0;

View File

@@ -1543,6 +1543,8 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
                                 PatternRewriter &rewriter) const override {
+    llvm::errs() << "dbg decompose fill_scalar\n";
     Location loc = op.getLoc();
     auto resType = op.getType().cast<BaseTensorType>();
     if (!resType.hasDtype()) {
@@ -4429,13 +4431,18 @@ public:
   LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
+    llvm::errs() << "dbg Try decompose empty strided\n";
+    //if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts))){
+    //  llvm::errs() << "dbg all size list ele const ints\n";
+    //  return rewriter.notifyMatchFailure(
+    //      op, "all size list elements must be constant ints");
+    //}
+    //if (!matchPattern(op.getStride(),
+    //                  m_TorchListOfConstantInts(strideListInts))){
+    //  llvm::errs() << "dbg all stride list ele const ints\n";
+    //  return rewriter.notifyMatchFailure(
+    //      op, "all stride list elements must be constant ints");}
     // We only support the cases with default stride values.
     // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
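     // (A default, i.e. contiguous row-major, stride satisfies
     // stride[i] = size[i+1] * ... * size[n-1]; for size=[2, 3, 4] that is
     // [3*4, 4, 1] = [12, 4, 1].)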
@@ -4451,16 +4458,18 @@ public:
         break;
       }
     }
-    if (!isDefaultStride)
+    if (!isDefaultStride){
+      llvm::errs() << "dbg non default strides\n";
       return rewriter.notifyMatchFailure(
           op, "only default strides supported for new_empty_strided op");
+    }
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
         op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(),
         op.getPinMemory(), /*memoryFormat=*/noneVal);
+    llvm::errs() << "dbg success\n";
     return success();

View File

@@ -227,12 +227,14 @@ public:
     if (!op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>())
       return rewriter.notifyMatchFailure(op, "is not trailing_ variant");
+    llvm::errs() << "dbg op name: " << op->getName().getStringRef() << "\n";
     SmallVector<StringRef> fragments;
     llvm::SplitString(op->getName().getStringRef(), fragments, ".");
     assert(fragments.size() >= 3 && fragments[2].endswith("_") &&
            "IsTrailingUnderscoreInplaceVariant incorrectly applied");
     fragments[2] = fragments[2].drop_back();
     std::string noUnderscoreName = llvm::join(fragments, ".");
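     // e.g. "torch.aten.add_.Tensor" splits into ["torch", "aten", "add_",
     // "Tensor"]; dropping the trailing underscore gives "torch.aten.add.Tensor".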
llvm::errs() << "dbg fragments good\n";
OperationState state(op->getLoc(), noUnderscoreName);
state.addTypes(op->getResultTypes());
@@ -241,14 +243,19 @@ public:
     // Note: No successors or regions. Torch JIT operators don't have any.
     assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 &&
            "Torch JIT operators shouldn't have regions or successors");
+    llvm::errs() << "dbg good regions, good successors\n";
     Operation *newOp = rewriter.create(state);
     auto tensor =
         rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
+    llvm::errs() << "dbg create copy_to_value_tensor\n";
     createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
                                   op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
+    llvm::errs() << "dbg success\n";
     return success();
   }
 };