[onnx] Handle `torch.aten` for inner product case (#3634)

The following case was failing to lower for einsum. This fixes up the
inner product issue.
Rob Suderman 2024-08-24 11:41:25 -07:00 committed by GitHub
parent 6cf139687d
commit f9766c89f6
2 changed files with 79 additions and 17 deletions
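
For context, "i,i" is einsum's implicit inner-product form: the one shared subscript is contracted away and the result is rank 0. A minimal sketch of that semantics in plain C++ (illustrative only, not torch-mlir code):

    #include <cstddef>
    #include <vector>

    // einsum "i,i": contract the single shared dimension down to a scalar.
    double einsumInnerProd(const std::vector<double> &a,
                           const std::vector<double> &b) {
      double acc = 0.0;
      for (std::size_t i = 0; i < a.size() && i < b.size(); ++i)
        acc += a[i] * b[i];
      return acc;
    }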


@@ -292,7 +292,7 @@ static bool parseEquation(const std::string &equation,
       inputToken.clear();
     } else if ((index < (equation.size() - 1)) &&
                (equation.substr(index, 2).find("->") != std::string::npos)) {
-      inputTokens.push_back(inputToken);
+      inputTokens.push_back(std::move(inputToken));
       inputToken.clear();
       currentVariable = kIsResult;
       index++;
@@ -301,6 +301,11 @@
     }
     index++;
   }
+
+  if (!inputToken.empty() && currentVariable == kIsInput) {
+    inputTokens.push_back(std::move(inputToken));
+  }
+
   return true;
 }
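
The flush added above matters for equations that never contain "->": with "i,i" the scan ends while the second operand's subscripts are still buffered in inputToken, so they were silently dropped. A standalone sketch of the same parsing shape (simplified, not the torch-mlir implementation):

    #include <cstddef>
    #include <string>
    #include <vector>

    // Split the operand half of an einsum equation on ',' and "->".
    std::vector<std::string> splitOperands(const std::string &eq) {
      std::vector<std::string> tokens;
      std::string cur;
      for (std::size_t i = 0; i < eq.size(); ++i) {
        if (eq[i] == ',') {
          tokens.push_back(cur);
          cur.clear();
        } else if (eq.compare(i, 2, "->") == 0) {
          tokens.push_back(cur);
          return tokens; // explicit output follows; operands are done
        } else {
          cur += eq[i];
        }
      }
      // Without this flush the last operand of an implicit equation such
      // as "i,i" is lost, which is the bug the hunk above fixes upstream.
      if (!cur.empty())
        tokens.push_back(cur);
      return tokens;
    }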
@@ -378,7 +383,9 @@ diagonalizeInputAndRewriteEquation(Location loc, PatternRewriter &rewriter,
   std::string resultString(resultTokens.begin(), resultTokens.end());
-  equation = llvm::join(inputStrings, ",") + "->" + resultString;
+  equation = llvm::join(inputStrings, ",");
+  if (!resultString.empty())
+    equation = equation + "->" + resultString;
   return true;
 }
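
Previously an empty result produced a dangling rejoined equation like "i,i->", which presumably tripped up the parser on the next pass; skipping the "->" keeps an implicit equation implicit. A small sketch of the corrected rejoin, assuming plain std::string inputs:

    #include <cstddef>
    #include <string>
    #include <vector>

    // Rejoin a rewritten equation; append "->" only for a non-empty result.
    std::string rebuildEquation(const std::vector<std::string> &inputs,
                                const std::string &result) {
      std::string eq;
      for (std::size_t i = 0; i < inputs.size(); ++i) {
        if (i != 0)
          eq += ",";
        eq += inputs[i];
      }
      if (!result.empty())
        eq += "->" + result; // "i,i" stays "i,i", not "i,i->"
      return eq;
    }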
@@ -389,7 +396,7 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
                                   int64_t contractingDimsLength,
                                   int64_t otherDimsLength,
                                   int64_t reduceDimsLength, bool isLhs) {
-  auto inputType = cast<BaseTensorType>(input.getType());
+  auto inputType = cast<ValueTensorType>(input.getType());
   auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
                    reduceDimsLength;
   SmallVector<Value> inputShapeTensor;
@@ -422,12 +429,22 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
   if (isLhs)
     appendDims(contractingDimsLength);
 
+  SmallVector<int64_t> resultShape;
+  for (auto value : outShapeTensor) {
+    int64_t v;
+    if (matchPattern(value, m_TorchConstantInt(&v))) {
+      resultShape.push_back(v);
+      continue;
+    }
+    resultShape.push_back(Torch::kUnknownSize);
+  }
+
   auto outShapeValue = rewriter.create<Torch::PrimListConstructOp>(
       loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
       outShapeTensor);
 
-  auto outType = inputType.getWithSizesAndDtype(std::nullopt,
-                                                inputType.getOptionalDtype());
+  auto outType =
+      inputType.getWithSizesAndDtype(resultShape, inputType.getOptionalDtype());
   return rewriter.create<Torch::AtenReshapeOp>(loc, outType, input,
                                                outShapeValue);
 }
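
The new resultShape loop recovers static extents wherever the shape list was built from constant ints, falling back to Torch::kUnknownSize for anything dynamic, so the reshape's result type now carries sizes instead of std::nullopt. A simplified model of that fold, with std::optional standing in for the matchPattern constant check:

    #include <cstdint>
    #include <optional>
    #include <vector>

    // Illustrative sentinel; not Torch::kUnknownSize's actual value.
    constexpr int64_t kUnknown = -1;

    // Dims that folded to a constant keep their static extent; the rest
    // degrade to the unknown sentinel.
    std::vector<int64_t>
    foldShape(const std::vector<std::optional<int64_t>> &dims) {
      std::vector<int64_t> shape;
      shape.reserve(dims.size());
      for (const auto &d : dims)
        shape.push_back(d ? *d : kUnknown);
      return shape;
    }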
@@ -508,17 +525,19 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
                                     SmallVector<char> &contractingDims,
                                     SmallVector<char> &otherDims,
                                     SmallVector<char> &reduceDims, bool isLhs) {
-  auto inputType = cast<BaseTensorType>(input.getType());
+  auto inputType = cast<ValueTensorType>(input.getType());
   llvm::SmallDenseMap<char, int64_t> dimTokenMap;
   for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
     dimTokenMap[dimTokens[idx]] = idx;
   }
 
+  SmallVector<int64_t> permuteShape;
   SmallVector<Value> permuteVec;
   auto appendDims = [&](SmallVector<char> dimTokens) {
     for (auto d : dimTokens) {
       permuteVec.push_back(rewriter.create<Torch::ConstantIntOp>(
           loc, rewriter.getI64IntegerAttr(dimTokenMap[d])));
+      permuteShape.push_back(inputType.getSizes()[dimTokenMap[d]]);
     }
   };
@@ -533,7 +552,8 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
   Value dstDims = rewriter.create<Torch::PrimListConstructOp>(
       loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
       permuteVec);
-  auto outType = inputType.getWithSizesAndDtype(std::nullopt,
+  auto outType = inputType.getWithSizesAndDtype(permuteShape,
                                                 inputType.getOptionalDtype());
   return rewriter.create<Torch::AtenPermuteOp>(loc, outType, input, dstDims);
 }
@@ -544,8 +564,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
                                    Value &result,
                                    SmallVector<char> &resultTokens,
                                    SmallVector<char> &finalResultTokens) {
-  auto lhsType = cast<BaseTensorType>(lhs.getType());
-  auto rhsType = cast<BaseTensorType>(rhs.getType());
+  auto lhsType = cast<ValueTensorType>(lhs.getType());
+  auto rhsType = cast<ValueTensorType>(rhs.getType());
 
   Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
                                         : rhsType.getOptionalDtype();
@@ -618,14 +638,18 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
                               contractingDims.size(), rhsOtherDims.size(),
                               rhsReduceDims.size(), false);
 
-  // perform matmul
-  auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType);
+  lhsType = cast<ValueTensorType>(lhs.getType());
+  rhsType = cast<ValueTensorType>(rhs.getType());
+  SmallVector<int64_t> outShape;
+  outShape.push_back(lhsType.getSizes()[0]);
+  outShape.push_back(lhsType.getSizes()[1]);
+  outShape.push_back(rhsType.getSizes()[2]);
 
-  result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
+  // perform matmul
+  auto outType = lhsType.getWithSizesAndDtype(outShape, outputDType);
+
+  if (contractingDims.size() != 0) {
+    result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
+  } else {
+    result = rewriter.create<Torch::AtenMulTensorOp>(loc, outType, lhs, rhs);
+  }
 
   // generate ideal result dims.
   generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
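
The new branch is the behavioral heart of this hunk: aten.matmul only makes sense when at least one subscript is contracted between the two operands; when contractingDims is empty the pairing is elementwise, so aten.mul.Tensor is emitted instead. Factored out as a sketch (the helper name and signature are illustrative, not torch-mlir API; the two create<> calls mirror the diff above):

    // Pick matmul when subscripts are contracted, elementwise mul when not.
    static Value emitPairwiseProduct(PatternRewriter &rewriter, Location loc,
                                     Type outType, Value lhs, Value rhs,
                                     size_t numContractingDims) {
      if (numContractingDims != 0)
        return rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
      return rewriter.create<Torch::AtenMulTensorOp>(loc, outType, lhs, rhs);
    }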
@@ -640,11 +664,21 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
     outShapeTensors.emplace_back(outDimShapeMap[d]);
   }
 
+  SmallVector<int64_t> resultShape;
+  for (auto value : outShapeTensors) {
+    int64_t v;
+    if (matchPattern(value, m_TorchConstantInt(&v))) {
+      resultShape.push_back(v);
+      continue;
+    }
+    resultShape.push_back(Torch::kUnknownSize);
+  }
+
   auto outResultShape = rewriter.create<Torch::PrimListConstructOp>(
       loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())),
       outShapeTensors);
   result = rewriter.create<Torch::AtenReshapeOp>(
-      loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result,
+      loc, lhsType.getWithSizesAndDtype(resultShape, outputDType), result,
       outResultShape);
   return success();
 }


@@ -100,3 +100,31 @@ func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.v
   %0:2 = torch.aten.fake_quantize_per_channel_affine_cachemask %arg0, %arg1, %arg2, %int0, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
   return %0#0 : !torch.vtensor<[?,?,?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: test_einsum_inner_prod
+func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
+  // CHECK: %[[INT5:.+]] = torch.constant.int 5
+  // CHECK: %[[INT1:.+]] = torch.constant.int 1
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
+  // CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]]
+  // CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
+  // CHECK: %[[RHS_PERM:.+]] = torch.aten.permute %arg1, %[[RHS_LIST]]
+  // CHECK: %[[LHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]], %[[INT5]]
+  // CHECK: %[[LHS_VIEW:.+]] = torch.aten.view %[[LHS_PERM]], %[[LHS_SHP]]
+  // CHECK: %[[RHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT5]], %[[INT1]]
+  // CHECK: %[[RHS_VIEW:.+]] = torch.aten.view %[[RHS_PERM]], %[[RHS_SHP]]
+  // CHECK: %[[BMM:.+]] = torch.aten.bmm %[[LHS_VIEW]], %[[RHS_VIEW]]
+  // CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[OUT_VIEW:.+]] = torch.aten.view %[[BMM]], %[[EMPTY]]
+  // CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[OUT_PERM:.+]] = torch.aten.permute %[[OUT_VIEW]], %[[EMPTY]]
+  // CHECK: return %[[OUT_PERM]]
+  %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>) -> !torch.list<vtensor>
+  %str = torch.constant.str "i,i"
+  %none_0 = torch.constant.none
+  %1 = torch.aten.einsum %str, %0, %none_0 : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[],f64>
+  return %1 : !torch.vtensor<[],f64>
+}