mirror of https://github.com/llvm/torch-mlir
[onnx] Handle `torch.aten` for inner product case (#3634)
The inner product case of einsum (for example the equation `i,i`, which has no explicit `->` result) was failing to lower. This fixes up the inner product issue.
pull/3655/merge
parent
6cf139687d
commit
f9766c89f6
|
@ -292,7 +292,7 @@ static bool parseEquation(const std::string &equation,
|
||||||
inputToken.clear();
|
inputToken.clear();
|
||||||
} else if ((index < (equation.size() - 1)) &&
|
} else if ((index < (equation.size() - 1)) &&
|
||||||
(equation.substr(index, 2).find("->") != std::string::npos)) {
|
(equation.substr(index, 2).find("->") != std::string::npos)) {
|
||||||
inputTokens.push_back(inputToken);
|
inputTokens.push_back(std::move(inputToken));
|
||||||
inputToken.clear();
|
inputToken.clear();
|
||||||
currentVariable = kIsResult;
|
currentVariable = kIsResult;
|
||||||
index++;
|
index++;
|
||||||
|
@ -301,6 +301,11 @@ static bool parseEquation(const std::string &equation,
|
||||||
}
|
}
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!inputToken.empty() && currentVariable == kIsInput) {
|
||||||
|
inputTokens.push_back(std::move(inputToken));
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -378,7 +383,9 @@ diagonalizeInputAndRewriteEquation(Location loc, PatternRewriter &rewriter,
|
||||||
|
|
||||||
std::string resultString(resultTokens.begin(), resultTokens.end());
|
std::string resultString(resultTokens.begin(), resultTokens.end());
|
||||||
|
|
||||||
equation = llvm::join(inputStrings, ",") + "->" + resultString;
|
equation = llvm::join(inputStrings, ",");
|
||||||
|
if (!resultString.empty())
|
||||||
|
equation = equation + "->" + resultString;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -389,7 +396,7 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
int64_t contractingDimsLength,
|
int64_t contractingDimsLength,
|
||||||
int64_t otherDimsLength,
|
int64_t otherDimsLength,
|
||||||
int64_t reduceDimsLength, bool isLhs) {
|
int64_t reduceDimsLength, bool isLhs) {
|
||||||
auto inputType = cast<BaseTensorType>(input.getType());
|
auto inputType = cast<ValueTensorType>(input.getType());
|
||||||
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
||||||
reduceDimsLength;
|
reduceDimsLength;
|
||||||
SmallVector<Value> inputShapeTensor;
|
SmallVector<Value> inputShapeTensor;
|
||||||
|
@ -422,12 +429,22 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
if (isLhs)
|
if (isLhs)
|
||||||
appendDims(contractingDimsLength);
|
appendDims(contractingDimsLength);
|
||||||
|
|
||||||
|
SmallVector<int64_t> resultShape;
|
||||||
|
for (auto value : outShapeTensor) {
|
||||||
|
int64_t v;
|
||||||
|
if (matchPattern(value, m_TorchConstantInt(&v))) {
|
||||||
|
resultShape.push_back(v);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
resultShape.push_back(Torch::kUnknownSize);
|
||||||
|
}
|
||||||
|
|
||||||
auto outShapeValue = rewriter.create<Torch::PrimListConstructOp>(
|
auto outShapeValue = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
|
loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
|
||||||
outShapeTensor);
|
outShapeTensor);
|
||||||
|
|
||||||
auto outType = inputType.getWithSizesAndDtype(std::nullopt,
|
auto outType =
|
||||||
inputType.getOptionalDtype());
|
inputType.getWithSizesAndDtype(resultShape, inputType.getOptionalDtype());
|
||||||
return rewriter.create<Torch::AtenReshapeOp>(loc, outType, input,
|
return rewriter.create<Torch::AtenReshapeOp>(loc, outType, input,
|
||||||
outShapeValue);
|
outShapeValue);
|
||||||
}
|
}
|
||||||
|
@ -508,17 +525,19 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
SmallVector<char> &contractingDims,
|
SmallVector<char> &contractingDims,
|
||||||
SmallVector<char> &otherDims,
|
SmallVector<char> &otherDims,
|
||||||
SmallVector<char> &reduceDims, bool isLhs) {
|
SmallVector<char> &reduceDims, bool isLhs) {
|
||||||
auto inputType = cast<BaseTensorType>(input.getType());
|
auto inputType = cast<ValueTensorType>(input.getType());
|
||||||
llvm::SmallDenseMap<char, int64_t> dimTokenMap;
|
llvm::SmallDenseMap<char, int64_t> dimTokenMap;
|
||||||
for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
|
for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
|
||||||
dimTokenMap[dimTokens[idx]] = idx;
|
dimTokenMap[dimTokens[idx]] = idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> permuteShape;
|
||||||
SmallVector<Value> permuteVec;
|
SmallVector<Value> permuteVec;
|
||||||
auto appendDims = [&](SmallVector<char> dimTokens) {
|
auto appendDims = [&](SmallVector<char> dimTokens) {
|
||||||
for (auto d : dimTokens) {
|
for (auto d : dimTokens) {
|
||||||
permuteVec.push_back(rewriter.create<Torch::ConstantIntOp>(
|
permuteVec.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(dimTokenMap[d])));
|
loc, rewriter.getI64IntegerAttr(dimTokenMap[d])));
|
||||||
|
permuteShape.push_back(inputType.getSizes()[dimTokenMap[d]]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -533,7 +552,8 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
Value dstDims = rewriter.create<Torch::PrimListConstructOp>(
|
Value dstDims = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
|
loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
|
||||||
permuteVec);
|
permuteVec);
|
||||||
auto outType = inputType.getWithSizesAndDtype(std::nullopt,
|
|
||||||
|
auto outType = inputType.getWithSizesAndDtype(permuteShape,
|
||||||
inputType.getOptionalDtype());
|
inputType.getOptionalDtype());
|
||||||
return rewriter.create<Torch::AtenPermuteOp>(loc, outType, input, dstDims);
|
return rewriter.create<Torch::AtenPermuteOp>(loc, outType, input, dstDims);
|
||||||
}
|
}
|
||||||
|
@ -544,8 +564,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
Value &result,
|
Value &result,
|
||||||
SmallVector<char> &resultTokens,
|
SmallVector<char> &resultTokens,
|
||||||
SmallVector<char> &finalResultTokens) {
|
SmallVector<char> &finalResultTokens) {
|
||||||
auto lhsType = cast<BaseTensorType>(lhs.getType());
|
auto lhsType = cast<ValueTensorType>(lhs.getType());
|
||||||
auto rhsType = cast<BaseTensorType>(rhs.getType());
|
auto rhsType = cast<ValueTensorType>(rhs.getType());
|
||||||
|
|
||||||
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
|
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
|
||||||
: rhsType.getOptionalDtype();
|
: rhsType.getOptionalDtype();
|
||||||
|
@ -618,14 +638,18 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
contractingDims.size(), rhsOtherDims.size(),
|
contractingDims.size(), rhsOtherDims.size(),
|
||||||
rhsReduceDims.size(), false);
|
rhsReduceDims.size(), false);
|
||||||
|
|
||||||
// perform matmul
|
lhsType = cast<ValueTensorType>(lhs.getType());
|
||||||
auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType);
|
rhsType = cast<ValueTensorType>(rhs.getType());
|
||||||
|
|
||||||
|
SmallVector<int64_t> outShape;
|
||||||
|
outShape.push_back(lhsType.getSizes()[0]);
|
||||||
|
outShape.push_back(lhsType.getSizes()[1]);
|
||||||
|
outShape.push_back(rhsType.getSizes()[2]);
|
||||||
|
|
||||||
|
// perform matmul
|
||||||
|
auto outType = lhsType.getWithSizesAndDtype(outShape, outputDType);
|
||||||
|
|
||||||
if (contractingDims.size() != 0) {
|
|
||||||
result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
|
result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
|
||||||
} else {
|
|
||||||
result = rewriter.create<Torch::AtenMulTensorOp>(loc, outType, lhs, rhs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate ideal result dims.
|
// generate ideal result dims.
|
||||||
generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
|
generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
|
||||||
|
@ -640,11 +664,21 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
outShapeTensors.emplace_back(outDimShapeMap[d]);
|
outShapeTensors.emplace_back(outDimShapeMap[d]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SmallVector<int64_t> resultShape;
|
||||||
|
for (auto value : outShapeTensors) {
|
||||||
|
int64_t v;
|
||||||
|
if (matchPattern(value, m_TorchConstantInt(&v))) {
|
||||||
|
resultShape.push_back(v);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
resultShape.push_back(Torch::kUnknownSize);
|
||||||
|
}
|
||||||
|
|
||||||
auto outResultShape = rewriter.create<Torch::PrimListConstructOp>(
|
auto outResultShape = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())),
|
loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())),
|
||||||
outShapeTensors);
|
outShapeTensors);
|
||||||
result = rewriter.create<Torch::AtenReshapeOp>(
|
result = rewriter.create<Torch::AtenReshapeOp>(
|
||||||
loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result,
|
loc, lhsType.getWithSizesAndDtype(resultShape, outputDType), result,
|
||||||
outResultShape);
|
outResultShape);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -100,3 +100,31 @@ func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.v
|
||||||
%0:2 = torch.aten.fake_quantize_per_channel_affine_cachemask %arg0, %arg1, %arg2, %int0, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
|
%0:2 = torch.aten.fake_quantize_per_channel_affine_cachemask %arg0, %arg1, %arg2, %int0, %int-128, %int127 : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?],f32>, !torch.vtensor<[?],si32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],i1>
|
||||||
return %0#0 : !torch.vtensor<[?,?,?,?],f32>
|
return %0#0 : !torch.vtensor<[?,?,?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: test_einsum_inner_prod
|
||||||
|
func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} {
|
||||||
|
// CHECK: %[[INT5:.+]] = torch.constant.int 5
|
||||||
|
// CHECK: %[[INT1:.+]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[INT0:.+]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
|
||||||
|
// CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]]
|
||||||
|
// CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]]
|
||||||
|
// CHECK: %[[RHS_PERM:.+]] = torch.aten.permute %arg1, %[[RHS_LIST]]
|
||||||
|
// CHECK: %[[LHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]], %[[INT5]]
|
||||||
|
// CHECK: %[[LHS_VIEW:.+]] = torch.aten.view %[[LHS_PERM]], %[[LHS_SHP]]
|
||||||
|
// CHECK: %[[RHS_SHP:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT5]], %[[INT1]]
|
||||||
|
// CHECK: %[[RHS_VIEW:.+]] = torch.aten.view %[[RHS_PERM]], %[[RHS_SHP]]
|
||||||
|
// CHECK: %[[BMM:.+]] = torch.aten.bmm %[[LHS_VIEW]], %[[RHS_VIEW]]
|
||||||
|
// CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
// CHECK: %[[OUT_VIEW:.+]] = torch.aten.view %[[BMM]], %[[EMPTY]]
|
||||||
|
// CHECK: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
// CHECK: %[[OUT_PERM:.+]] = torch.aten.permute %[[OUT_VIEW]], %[[EMPTY]]
|
||||||
|
// CHECK: return %[[OUT_PERM]]
|
||||||
|
%0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>) -> !torch.list<vtensor>
|
||||||
|
%str = torch.constant.str "i,i"
|
||||||
|
%none_0 = torch.constant.none
|
||||||
|
%1 = torch.aten.einsum %str, %0, %none_0 : !torch.str, !torch.list<vtensor>, !torch.none -> !torch.vtensor<[],f64>
|
||||||
|
return %1 : !torch.vtensor<[],f64>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue