debugging and hacks

llama2_wip
dan 2023-09-27 17:33:34 +00:00
parent f70978c034
commit 1e3df9c943
3 changed files with 39 additions and 23 deletions

View File

@@ -366,11 +366,7 @@ public:
     auto [inputShape, outputShape] =
         getInputAndOutputShape(op.getSelf(), outputSizeTorchInt, resultType.getShape());
-    llvm::errs() << "inputShape: ";
-    llvm::interleaveComma(inputShape, llvm::errs());
-    llvm::errs() << " \noutput shape: ";
-    llvm::interleaveComma(outputShape, llvm::errs());
 
     // Currently, we only handle the cases where each dimension is either
     // being expanded or collapsed. We do not handle cases where it's neither
@@ -395,10 +391,6 @@ public:
               m_TorchTensorSizeInt(op.getSelf(), &inputDim))) {
         unchangedDims.push_back(std::make_pair(inputDim, outputDim));
-        llvm::errs() << "inputDim: ";
-        llvm::errs() << inputDim;
-        llvm::errs() << " \noutput Dim: ";
-        llvm::errs() << outputDim;
       }
     }
     // Mark the end of the input/output shapes
@@ -440,10 +432,10 @@ public:
     // Used for ensuring that we don't have an ambiguous expansion
     bool assumedDynamicDimNotSplit = false;
     while (inputDim < nextUnchangedInput && outputDim < nextUnchangedOutput) {
-      auto inputShapeSlice = // 1, ?, 4096
+      auto inputShapeSlice = // ?, 4096
           MutableArrayRef<int64_t>(inputShape)
               .slice(inputDim, nextUnchangedInput - inputDim);
-      auto outputShapeSlice = // ?, 4096
+      auto outputShapeSlice = //1, ?, 4096
           MutableArrayRef<int64_t>(outputShape)
               .slice(outputDim, nextUnchangedOutput - outputDim);
       SmallVector<int64_t> inputSliceIndices;
@@ -483,13 +475,21 @@ public:
       /// known to have the same number of elements.
       } else if (inputShapeSlice[0] == kUnknownSize) {
         // If the input is dynamic, assume it is not split
-        checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
-                            outputSizeInt[outputDim]);
+        // Elide any unit dims in the output and checkDimEqual on dynamic dims
+        int64_t idx = 0;
+        while (idx < outputShapeSlice.size() && outputShapeSlice[idx] == 1) {
+          outputSliceIndices.push_back(idx++);
+        }
+        //checkDimEqualHelper(rewriter, loc, inputSize[inputDim],
+        //                    outputSizeInt[idx]);
         // If output dimension is not dynamic, improve static information of
         // input
-        inputShape[inputDim] = outputShape[outputDim];
+        inputShape[inputDim] = outputShape[idx];
         inputSliceIndices.push_back(0);
-        outputSliceIndices.push_back(0);
+        outputSliceIndices.push_back(idx);
         assumedDynamicDimNotSplit = true;
       } else if (inputShapeSlice[0] == 1 && outputShapeSlice[0] == kUnknownSize) {
         int64_t idx = 0;
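
Note: the hunk above changes how the view-op lowering pairs a dynamic input dim against the output slice. Instead of calling checkDimEqualHelper right away, any leading unit dims in the output slice (e.g. reshaping ?x4096 to 1x?x4096) are recorded as elided first, and the dynamic input dim is then paired with the first non-unit output dim. A minimal standalone sketch of that bookkeeping, using plain std::vector in place of the pass's MutableArrayRef/SmallVector (illustration only, not part of the commit):

#include <cstdint>
#include <cstdio>
#include <vector>

// Dynamic dims are marked with kUnknownSize, mirroring the convention above.
constexpr int64_t kUnknownSize = -1;

int main() {
  std::vector<int64_t> inputShapeSlice = {kUnknownSize, 4096};     // ?, 4096
  std::vector<int64_t> outputShapeSlice = {1, kUnknownSize, 4096}; // 1, ?, 4096
  std::vector<int64_t> outputSliceIndices;

  if (inputShapeSlice[0] == kUnknownSize) {
    // Skip leading unit dims in the output before pairing the dynamic input
    // dim with the first non-unit output dim.
    size_t idx = 0;
    while (idx < outputShapeSlice.size() && outputShapeSlice[idx] == 1)
      outputSliceIndices.push_back(idx++);
    // Here idx == 1: the dynamic input dim lines up with the '?' output dim,
    // so static size information can be propagated across the reshape.
    std::printf("dynamic input dim pairs with output dim %zu\n", idx);
  }
  return 0;
}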

View File

@@ -1543,6 +1543,8 @@ public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
                                 PatternRewriter &rewriter) const override {
+    llvm::errs() << "dbg decompose fill_scalar\n";
     Location loc = op.getLoc();
     auto resType = op.getType().cast<BaseTensorType>();
     if (!resType.hasDtype()) {
@@ -4429,13 +4431,18 @@ public:
   LogicalResult matchAndRewrite(AtenEmptyStridedOp op,
                                 PatternRewriter &rewriter) const override {
     SmallVector<int64_t> sizeListInts, strideListInts;
-    if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all size list elements must be constant ints");
-    if (!matchPattern(op.getStride(),
-                      m_TorchListOfConstantInts(strideListInts)))
-      return rewriter.notifyMatchFailure(
-          op, "all stride list elements must be constant ints");
+    llvm::errs() << "dbg Try decompose empty strided\n";
+    //if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(sizeListInts))){
+    //  llvm::errs() << "dbg all size list ele const ints\n";
+    //  return rewriter.notifyMatchFailure(
+    //      op, "all size list elements must be constant ints");
+    //}
+    //if (!matchPattern(op.getStride(),
+    //                  m_TorchListOfConstantInts(strideListInts))){
+    //  llvm::errs() << "dbg all stride list ele const ints\n";
+    //  return rewriter.notifyMatchFailure(
+    //      op, "all stride list elements must be constant ints");}
 
     // We only support the cases with default stride values.
     // For ex: aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1])
@@ -4451,16 +4458,18 @@ public:
         break;
       }
     }
-    if (!isDefaultStride)
+    if (!isDefaultStride){
+      llvm::errs() << "dbg non default strides\n";
       return rewriter.notifyMatchFailure(
           op, "only default strides supported for new_empty_strided op");
+    }
 
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
         op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(), op.getDevice(),
         op.getPinMemory(), /*memoryFormat=*/noneVal);
+    llvm::errs() << "dbg success\n";
     return success();
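
Note: the two hunks above only loosen the guards in the AtenEmptyStridedOp decomposition; the check that still rejects ops is the "default strides" test illustrated by the comment aten.new_empty_strided(self, size=[2, 3, 4], stride=[12, 4, 1]). Default (contiguous, row-major) strides are the running product of the trailing sizes, so size [2, 3, 4] yields strides [12, 4, 1]. A standalone sketch of that computation (contiguousStrides is a made-up helper for illustration, not part of the pass):

#include <cstdint>
#include <cstdio>
#include <vector>

// Row-major contiguous strides: stride[i] = product of sizes[i+1 .. end).
static std::vector<int64_t> contiguousStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * sizes[i + 1];
  return strides;
}

int main() {
  std::vector<int64_t> sizes = {2, 3, 4};
  std::vector<int64_t> strides = {12, 4, 1};
  // Mirrors the isDefaultStride decision in the hunk above.
  bool isDefaultStride = (strides == contiguousStrides(sizes));
  std::puts(isDefaultStride ? "default strides" : "non default strides");
  return 0;
}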

View File

@@ -227,12 +227,14 @@ public:
     if (!op->hasTrait<Torch::OpTrait::IsTrailingUnderscoreInplaceVariant>())
       return rewriter.notifyMatchFailure(op, "is not trailing_ variant");
+    llvm::errs() << "dbg op name: " << op->getName().getStringRef() << "\n";
     SmallVector<StringRef> fragments;
     llvm::SplitString(op->getName().getStringRef(), fragments, ".");
     assert(fragments.size() >= 3 && fragments[2].endswith("_") &&
            "IsTrailingUnderscoreInplaceVariant incorrectly applied");
     fragments[2] = fragments[2].drop_back();
     std::string noUnderscoreName = llvm::join(fragments, ".");
+    llvm::errs() << "dbg fragments good\n";
 
     OperationState state(op->getLoc(), noUnderscoreName);
     state.addTypes(op->getResultTypes());
@@ -241,14 +243,19 @@ public:
     // Note: No successors or regions. Torch JIT operators don't have any.
     assert(op->getNumRegions() == 0 && op->getNumSuccessors() == 0 &&
            "Torch JIT operators shouldn't have regions or successors");
+    llvm::errs() << "dbg good regions, good successors\n";
     Operation *newOp = rewriter.create(state);
     auto tensor =
         rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
+    llvm::errs() << "dbg create copy_to_value_tensor\n";
     createOverwriteTensorContents(rewriter, op->getLoc(), tensor,
                                   op->getOperand(0));
     rewriter.replaceOp(op, op->getOperand(0));
+    llvm::errs() << "dbg success\n";
     return success();
   }
 };
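
Note: the debug prints above trace the existing rewrite that maps a trailing-underscore (inplace) op onto its non-underscore variant by splitting the op name on '.' and dropping the '_' from the third fragment, e.g. a name of the form torch.aten.fill_.Scalar would become torch.aten.fill.Scalar (example name chosen for illustration). A plain std::string sketch of that rename, without the LLVM StringRef/SplitString/join helpers the pass uses:

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  std::string name = "torch.aten.fill_.Scalar"; // example inplace op name

  // Split on '.' like llvm::SplitString does.
  std::vector<std::string> fragments;
  std::stringstream ss(name);
  for (std::string piece; std::getline(ss, piece, '.');)
    fragments.push_back(piece);

  // Drop the trailing '_' from the op fragment (fragments[2]).
  if (fragments.size() >= 3 && !fragments[2].empty() && fragments[2].back() == '_')
    fragments[2].pop_back();

  // Rejoin with '.' to get the name of the non-inplace variant.
  std::string noUnderscoreName;
  for (size_t i = 0; i < fragments.size(); ++i) {
    if (i)
      noUnderscoreName += ".";
    noUnderscoreName += fragments[i];
  }

  std::cout << noUnderscoreName << "\n"; // torch.aten.fill.Scalar
  return 0;
}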