mirror of https://github.com/llvm/torch-mlir
[onnx] Handle `torch.aten` for inner product case (#3634)
The following case was failing to lower for einsum. This fixes up the inner product issue.pull/3655/merge
parent
6cf139687d
commit
f9766c89f6
|
@ -292,7 +292,7 @@ static bool parseEquation(const std::string &equation,
|
|||
inputToken.clear();
|
||||
} else if ((index < (equation.size() - 1)) &&
|
||||
(equation.substr(index, 2).find("->") != std::string::npos)) {
|
||||
inputTokens.push_back(inputToken);
|
||||
inputTokens.push_back(std::move(inputToken));
|
||||
inputToken.clear();
|
||||
currentVariable = kIsResult;
|
||||
index++;
|
||||
|
@ -301,6 +301,11 @@ static bool parseEquation(const std::string &equation,
|
|||
}
|
||||
index++;
|
||||
}
|
||||
|
||||
if (!inputToken.empty() && currentVariable == kIsInput) {
|
||||
inputTokens.push_back(std::move(inputToken));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -378,7 +383,9 @@ diagonalizeInputAndRewriteEquation(Location loc, PatternRewriter &rewriter,
|
|||
|
||||
std::string resultString(resultTokens.begin(), resultTokens.end());
|
||||
|
||||
equation = llvm::join(inputStrings, ",") + "->" + resultString;
|
||||
equation = llvm::join(inputStrings, ",");
|
||||
if (!resultString.empty())
|
||||
equation = equation + "->" + resultString;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -389,7 +396,7 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
|||
int64_t contractingDimsLength,
|
||||
int64_t otherDimsLength,
|
||||
int64_t reduceDimsLength, bool isLhs) {
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
auto inputType = cast<ValueTensorType>(input.getType());
|
||||
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
||||
reduceDimsLength;
|
||||
SmallVector<Value> inputShapeTensor;
|
||||
|
@ -422,12 +429,22 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
|||
if (isLhs)
|
||||
appendDims(contractingDimsLength);
|
||||
|
||||
SmallVector<int64_t> resultShape;
|
||||
for (auto value : outShapeTensor) {
|
||||
int64_t v;
|
||||
if (matchPattern(value, m_TorchConstantInt(&v))) {
|
||||
resultShape.push_back(v);
|
||||
continue;
|
||||
}
|
||||
resultShape.push_back(Torch::kUnknownSize);
|
||||
}
|
||||
|
||||
auto outShapeValue = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
|
||||
outShapeTensor);
|
||||
|
||||
auto outType = inputType.getWithSizesAndDtype(std::nullopt,
|
||||
inputType.getOptionalDtype());
|
||||
auto outType =
|
||||
inputType.getWithSizesAndDtype(resultShape, inputType.getOptionalDtype());
|
||||
return rewriter.create<Torch::AtenReshapeOp>(loc, outType, input,
|
||||
outShapeValue);
|
||||
}
|
||||
|
@ -508,17 +525,19 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
|
|||
SmallVector<char> &contractingDims,
|
||||
SmallVector<char> &otherDims,
|
||||
SmallVector<char> &reduceDims, bool isLhs) {
|
||||
auto inputType = cast<BaseTensorType>(input.getType());
|
||||
auto inputType = cast<ValueTensorType>(input.getType());
|
||||
llvm::SmallDenseMap<char, int64_t> dimTokenMap;
|
||||
for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
|
||||
dimTokenMap[dimTokens[idx]] = idx;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> permuteShape;
|
||||
SmallVector<Value> permuteVec;
|
||||
auto appendDims = [&](SmallVector<char> dimTokens) {
|
||||
for (auto d : dimTokens) {
|
||||
permuteVec.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dimTokenMap[d])));
|
||||
permuteShape.push_back(inputType.getSizes()[dimTokenMap[d]]);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -533,7 +552,8 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
|
|||
Value dstDims = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
|
||||
permuteVec);
|
||||
auto outType = inputType.getWithSizesAndDtype(std::nullopt,
|
||||
|
||||
auto outType = inputType.getWithSizesAndDtype(permuteShape,
|
||||
inputType.getOptionalDtype());
|
||||
return rewriter.create<Torch::AtenPermuteOp>(loc, outType, input, dstDims);
|
||||
}
|
||||
|
@ -544,8 +564,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
|||
Value &result,
|
||||
SmallVector<char> &resultTokens,
|
||||
SmallVector<char> &finalResultTokens) {
|
||||
auto lhsType = cast<BaseTensorType>(lhs.getType());
|
||||
auto rhsType = cast<BaseTensorType>(rhs.getType());
|
||||
auto lhsType = cast<ValueTensorType>(lhs.getType());
|
||||
auto rhsType = cast<ValueTensorType>(rhs.getType());
|
||||
|
||||
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
|
||||
: rhsType.getOptionalDtype();
|
||||
|
@ -618,14 +638,18 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
|||
contractingDims.size(), rhsOtherDims.size(),
|
||||
rhsReduceDims.size(), false);
|
||||
|
||||
// perform matmul
|
||||
auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType);
|
||||
lhsType = cast<ValueTensorType>(lhs.getType());
|
||||
rhsType = cast<ValueTensorType>(rhs.getType());
|
||||
|
||||
SmallVector<int64_t> outShape;
|
||||
outShape.push_back(lhsType.getSizes()[0]);
|
||||
outShape.push_back(lhsType.getSizes()[1]);
|
||||
outShape.push_back(rhsType.getSizes()[2]);
|
||||
|
||||
// perform matmul
|
||||
auto outType = lhsType.getWithSizesAndDtype(outShape, outputDType);
|
||||
|
||||
if (contractingDims.size() != 0) {
|
||||
result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
|
||||
} else {
|
||||
result = rewriter.create<Torch::AtenMulTensorOp>(loc, outType, lhs, rhs);
|
||||
}
|
||||
|
||||
// generate ideal result dims.
|
||||
generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
|
||||
|
@ -640,11 +664,21 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
|||
outShapeTensors.emplace_back(outDimShapeMap[d]);
|
||||
}
|
||||
|
||||
SmallVector<int64_t> resultShape;
|
||||
for (auto value : outShapeTensors) {
|
||||
int64_t v;
|
||||
if (matchPattern(value, m_TorchConstantInt(&v))) {
|
||||
resultShape.push_back(v);
|
||||
continue;
|
||||
}
|
||||
resultShape.push_back(Torch::kUnknownSize);
|
||||
}
|
||||
|
||||
auto outResultShape = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())),
|
||||
outShapeTensors);
|
||||
result = rewriter.create<Torch::AtenReshapeOp>(
|
||||
loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result,
|
||||
loc, lhsType.getWithSizesAndDtype(resultShape, outputDType), result,
|
||||
outResultShape);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -100,3 +100,31 @@ func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.v
|
|||
%0:2 = torch.aten.fake_quantize_per_channel_affine_cachemask %arg0, %arg1, %arg2, %int0, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
|
||||
return %0#0 : !torch.vtensor<[?,?,?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: test_einsum_inner_prod
|
||||
func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
|
||||
// CHECK: %[[INT5:.+]] = torch.constant.int 5
|
||||
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||
// CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
|
||||
// CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]]
|
||||
// CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
|
||||
// CHECK: %[[RHS_PERM:.+]] = torch.aten.permute %arg1, %[[RHS_LIST]]
|
||||
// CHECK: %[[LHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]], %[[INT5]]
|
||||
// CHECK: %[[LHS_VIEW:.+]] = torch.aten.view %[[LHS_PERM]], %[[LHS_SHP]]
|
||||
// CHECK: %[[RHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT5]], %[[INT1]]
|
||||
// CHECK: %[[RHS_VIEW:.+]] = torch.aten.view %[[RHS_PERM]], %[[RHS_SHP]]
|
||||
// CHECK: %[[BMM:.+]] = torch.aten.bmm %[[LHS_VIEW]], %[[RHS_VIEW]]
|
||||
// CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[OUT_VIEW:.+]] = torch.aten.view %[[BMM]], %[[EMPTY]]
|
||||
// CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
// CHECK: %[[OUT_PERM:.+]] = torch.aten.permute %[[OUT_VIEW]], %[[EMPTY]]
|
||||
// CHECK: return %[[OUT_PERM]]
|
||||
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>) -> !torch.list<vtensor>
|
||||
%str = torch.constant.str "i,i"
|
||||
%none_0 = torch.constant.none
|
||||
%1 = torch.aten.einsum %str, %0, %none_0 : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[],f64>
|
||||
return %1 : !torch.vtensor<[],f64>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue