mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Support Einsum Op (#2230)
As the title says, this adds support for the torch.aten.einsum op. Only static shapes are supported for now because of a known issue; the fix for that issue is here: https://github.com/llvm/torch-mlir/pull/2154 Co-authored-by: Jiawei Wu [wujiawei.aml@bytedance.com](mailto:wujiawei.aml@bytedance.com) pull/2628/head snapshot-20231210.1048
parent
07c3e11f56
commit
96fcde4d77
|
@ -8447,6 +8447,31 @@ def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ODS definition of `torch.aten.einsum`. The summary string marks this as a
// generated definition.
def Torch_AtenEinsumOp : Torch_Op<"aten.einsum",
    [AllowsTypeRefinement, HasValueSemantics, ReadOnly]> {
  let summary = "Generated op for `aten::einsum : (str, Tensor[], int[]?) -> (Tensor)`";
  let arguments = (ins
    Torch_StringType:$equation,
    AnyTorchListOfTensorType:$tensors,
    AnyTorchOptionalListOfTorchIntType:$path
  );
  let results = (outs AnyTorchTensorType:$result);
  let hasCustomAssemblyFormat = 1;
  // The `3, 1` passed to the default parse/print helpers below matches the
  // three arguments and single result declared above.
  let extraClassDefinition = [{
    ParseResult AtenEinsumOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenEinsumOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
|
||||||
|
|
||||||
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
|
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -11321,6 +11321,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
" return %5 : !torch.int\n"
|
" return %5 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
|
||||||
|
" %true = torch.constant.bool true\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
|
||||||
|
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||||
|
" %2 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
|
||||||
|
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %3 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %4 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
|
||||||
|
" torch.prim.Loop %4, %true, init() {\n"
|
||||||
|
" ^bb0(%arg3: !torch.int):\n"
|
||||||
|
" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
|
||||||
|
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
|
||||||
|
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||||
|
" torch.prim.Loop.condition %true, iter()\n"
|
||||||
|
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||||
|
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %5 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %int4 = torch.constant.int 4\n"
|
" %int4 = torch.constant.int 4\n"
|
||||||
" return %int4 : !torch.int\n"
|
" return %int4 : !torch.int\n"
|
||||||
|
|
|
@ -187,6 +187,358 @@ static SmallVector<int64_t> computeDimsOrderForMoveDim(int64_t srcDimInt,
|
||||||
return dimsOrder;
|
return dimsOrder;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool parseEquation(const std::string &equation,
|
||||||
|
SmallVector<SmallVector<char>> &inputTokens,
|
||||||
|
SmallVector<char> &resultTokens) {
|
||||||
|
SmallVector<char> inputToken;
|
||||||
|
size_t index = 0;
|
||||||
|
enum EquationVariable { kIsInput, kIsResult };
|
||||||
|
EquationVariable currentVariable = kIsInput;
|
||||||
|
while (index < equation.size()) {
|
||||||
|
if (std::isalpha(equation[index])) {
|
||||||
|
if (currentVariable == kIsInput) {
|
||||||
|
inputToken.push_back(equation[index]);
|
||||||
|
} else {
|
||||||
|
resultTokens.push_back(equation[index]);
|
||||||
|
}
|
||||||
|
} else if (equation[index] == ',') {
|
||||||
|
inputTokens.push_back(inputToken);
|
||||||
|
inputToken.clear();
|
||||||
|
} else if ((index < (equation.size() - 1)) &&
|
||||||
|
(equation.substr(index, 2).find("->") != std::string::npos)) {
|
||||||
|
inputTokens.push_back(inputToken);
|
||||||
|
inputToken.clear();
|
||||||
|
currentVariable = kIsResult;
|
||||||
|
index++;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] =>
|
||||||
|
// [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
|
||||||
|
static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
|
Value input, int64_t batchDimsLength,
|
||||||
|
int64_t contractingDimsLength,
|
||||||
|
int64_t otherDimsLength,
|
||||||
|
int64_t reduceDimsLength, bool isLhs) {
|
||||||
|
auto inputType = input.getType().cast<BaseTensorType>();
|
||||||
|
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
||||||
|
reduceDimsLength;
|
||||||
|
SmallVector<Value> inputShapeTensor;
|
||||||
|
for (auto i = 0; i < inputRank; ++i) {
|
||||||
|
inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
|
||||||
|
loc, input,
|
||||||
|
rewriter.create<Torch::ConstantIntOp>(loc,
|
||||||
|
rewriter.getI64IntegerAttr(i))));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> outShapeTensor;
|
||||||
|
Value constOne =
|
||||||
|
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
auto dimOffset = 0;
|
||||||
|
|
||||||
|
auto appendDims = [&](int64_t dimLength) {
|
||||||
|
Value prod = constOne;
|
||||||
|
for (auto i = 0; i < dimLength; ++i) {
|
||||||
|
prod = rewriter.create<AtenMulIntOp>(loc, prod,
|
||||||
|
inputShapeTensor[i + dimOffset]);
|
||||||
|
}
|
||||||
|
outShapeTensor.emplace_back(prod);
|
||||||
|
dimOffset += dimLength;
|
||||||
|
};
|
||||||
|
|
||||||
|
appendDims(batchDimsLength);
|
||||||
|
if (!isLhs)
|
||||||
|
appendDims(contractingDimsLength);
|
||||||
|
appendDims(otherDimsLength + reduceDimsLength);
|
||||||
|
if (isLhs)
|
||||||
|
appendDims(contractingDimsLength);
|
||||||
|
|
||||||
|
auto outShapeValue = rewriter.create<Torch::PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
|
||||||
|
outShapeTensor);
|
||||||
|
|
||||||
|
auto outType = inputType.getWithSizesAndDtype(std::nullopt,
|
||||||
|
inputType.getOptionalDtype());
|
||||||
|
return rewriter.create<Torch::AtenReshapeOp>(loc, outType, input,
|
||||||
|
outShapeValue);
|
||||||
|
}
|
||||||
|
|
||||||
|
// classify every dim token into different categories. Note that although we
|
||||||
|
// parse out reduce dims, we delay their execution until
|
||||||
|
// `performLastPermuteAndReduce`.
|
||||||
|
static void parseDimTokens(
|
||||||
|
SmallVector<char> &lhsTokens, SmallVector<char> &rhsTokens,
|
||||||
|
SmallVector<char> &finalResultTokens, SmallVector<char> &contractingDims,
|
||||||
|
SmallVector<char> &lhsReduceDims, SmallVector<char> &rhsReduceDims,
|
||||||
|
SmallVector<char> &batchingDims, SmallVector<char> &lhsOtherDims,
|
||||||
|
SmallVector<char> &rhsOtherDims) {
|
||||||
|
llvm::SmallDenseSet<char> lhsTokenSet(lhsTokens.begin(), lhsTokens.end());
|
||||||
|
llvm::SmallDenseSet<char> rhsTokenSet(rhsTokens.begin(), rhsTokens.end());
|
||||||
|
llvm::SmallDenseSet<char> finalResultTokenSet(finalResultTokens.begin(),
|
||||||
|
finalResultTokens.end());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < lhsTokens.size(); ++i) {
|
||||||
|
bool rhsContains = rhsTokenSet.contains(lhsTokens[i]);
|
||||||
|
bool finalResultConatins = finalResultTokenSet.contains(lhsTokens[i]);
|
||||||
|
// batching dim
|
||||||
|
if (rhsContains && finalResultConatins) {
|
||||||
|
batchingDims.push_back(lhsTokens[i]);
|
||||||
|
// reduce dim of lhs
|
||||||
|
} else if (!rhsContains && !finalResultConatins) {
|
||||||
|
lhsReduceDims.push_back(lhsTokens[i]);
|
||||||
|
// other dim of lhs
|
||||||
|
} else if (finalResultConatins) {
|
||||||
|
lhsOtherDims.push_back(lhsTokens[i]);
|
||||||
|
// contracting dim of lhs
|
||||||
|
} else if (rhsContains) {
|
||||||
|
contractingDims.push_back(lhsTokens[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < rhsTokens.size(); ++i) {
|
||||||
|
bool lhsContains = lhsTokenSet.contains(rhsTokens[i]);
|
||||||
|
bool finalResultConatins = finalResultTokenSet.contains(rhsTokens[i]);
|
||||||
|
// batching dim
|
||||||
|
if (lhsContains && finalResultConatins) {
|
||||||
|
// reduce dim of rhs
|
||||||
|
} else if (!lhsContains && !finalResultConatins) {
|
||||||
|
rhsReduceDims.push_back(rhsTokens[i]);
|
||||||
|
// other dim of rhs
|
||||||
|
} else if (finalResultConatins) {
|
||||||
|
rhsOtherDims.push_back(rhsTokens[i]);
|
||||||
|
// contracting dim of rhs
|
||||||
|
} else if (lhsContains) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Appends to `resultTokens` the subscript order of the un-collapsed matmul
// result:
//   [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
//    *rhsReduceDims]
static void generateIdealReusltDimTokens(SmallVector<char> &batchingDims,
                                         SmallVector<char> &lhsOtherDims,
                                         SmallVector<char> &rhsOtherDims,
                                         SmallVector<char> &lhsReduceDims,
                                         SmallVector<char> &rhsReduceDims,
                                         SmallVector<char> &resultTokens) {
  // Note the emission order differs from the parameter order.
  const SmallVector<char> *groupsInOrder[] = {
      &batchingDims, &lhsOtherDims, &lhsReduceDims, &rhsOtherDims,
      &rhsReduceDims};
  for (const SmallVector<char> *group : groupsInOrder)
    resultTokens.append(group->begin(), group->end());
}
|
||||||
|
|
||||||
|
// Emits an `aten.permute` that regroups the dimensions of `input` (whose
// current subscripts are `dimTokens`) for the matmul step:
//   isLhs:  [*batchingDims, *otherDims, *reduceDims, *contractingDims]
//   !isLhs: [*batchingDims, *contractingDims, *otherDims, *reduceDims]
// The result type keeps the input dtype and drops the sizes.
static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
                                    Value input, SmallVector<char> &dimTokens,
                                    SmallVector<char> &batchingDims,
                                    SmallVector<char> &contractingDims,
                                    SmallVector<char> &otherDims,
                                    SmallVector<char> &reduceDims, bool isLhs) {
  auto inputType = input.getType().cast<BaseTensorType>();

  // Subscript -> current dimension index in `input`.
  llvm::SmallDenseMap<char, int64_t> positionOf;
  int64_t position = 0;
  for (char token : dimTokens)
    positionOf[token] = position++;

  // Emits one constant dimension index per subscript of `group`.
  SmallVector<Value> permutation;
  auto emitGroup = [&](const SmallVector<char> &group) {
    for (char token : group) {
      permutation.push_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(positionOf[token])));
    }
  };

  emitGroup(batchingDims);
  if (isLhs) {
    emitGroup(otherDims);
    emitGroup(reduceDims);
    emitGroup(contractingDims);
  } else {
    emitGroup(contractingDims);
    emitGroup(otherDims);
    emitGroup(reduceDims);
  }

  auto intListType =
      Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
  Value permutationList = rewriter.create<Torch::PrimListConstructOp>(
      loc, intListType, permutation);
  auto unrankedType = inputType.getWithSizesAndDtype(
      std::nullopt, inputType.getOptionalDtype());
  return rewriter.create<Torch::AtenPermuteOp>(loc, unrankedType, input,
                                               permutationList);
}
|
||||||
|
|
||||||
|
// Performs one pairwise einsum step: contracts `lhs` (subscripts `lhsTokens`)
// with `rhs` (subscripts `rhsTokens`) through `aten.matmul`.
//
// On success `result` holds the product reshaped to
//   [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
//    *rhsReduceDims]
// and those subscripts are appended to `resultTokens`. Subscripts absent from
// `finalResultTokens` and from the other operand (reduce dims) are kept in the
// result; they are summed away later.
//
// Fails for a pure elementwise (Hadamard) product, i.e. when there are no
// contracting dims and neither operand has dims of its own. The classification
// and this check run before any op is created, so a failure of this step
// leaves the IR untouched. Callers chaining several steps may still have
// emitted ops in earlier, successful steps.
static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
                                   Value lhs, SmallVector<char> &lhsTokens,
                                   Value rhs, SmallVector<char> &rhsTokens,
                                   Value &result,
                                   SmallVector<char> &resultTokens,
                                   SmallVector<char> &finalResultTokens) {
  auto lhsType = lhs.getType().cast<BaseTensorType>();
  auto rhsType = rhs.getType().cast<BaseTensorType>();

  Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
                                        : rhsType.getOptionalDtype();

  // parse batch, contracting, other, reduce dims of lhs and rhs
  SmallVector<char> contractingDims;
  SmallVector<char> lhsReduceDims;
  SmallVector<char> rhsReduceDims;
  SmallVector<char> lhsOtherDims;
  SmallVector<char> rhsOtherDims;
  SmallVector<char> batchingDims;
  parseDimTokens(lhsTokens, rhsTokens, finalResultTokens, contractingDims,
                 lhsReduceDims, rhsReduceDims, batchingDims, lhsOtherDims,
                 rhsOtherDims);

  // Reject unsupported equations before emitting anything.
  if (contractingDims.empty() && lhsOtherDims.empty() &&
      rhsOtherDims.empty()) {
    return rewriter.notifyMatchFailure(
        loc, "Hadamard product is currently not supported");
  }

  // Subscript -> runtime size, separately for each operand.
  llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
  for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
    char d = lhsTokens[idx];
    lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
        loc, lhs,
        rewriter.create<Torch::ConstantIntOp>(loc,
                                              rewriter.getI64IntegerAttr(idx)));
  }
  llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
  for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
    char d = rhsTokens[idx];
    rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
        loc, rhs,
        rewriter.create<Torch::ConstantIntOp>(loc,
                                              rewriter.getI64IntegerAttr(idx)));
  }

  // Subscript -> size in the result. A subscript present on both sides takes
  // the max of the two sizes.
  llvm::SmallDenseMap<char, Value> outDimShapeMap;
  auto generateOutDimShapeMap = [&](SmallVector<char> &dims) {
    for (auto d : dims) {
      bool lhsContains = lhsDimShapeMap.count(d) > 0;
      bool rhsContains = rhsDimShapeMap.count(d) > 0;
      if (lhsContains && rhsContains) {
        outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
            loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
      } else if (lhsContains) {
        outDimShapeMap[d] = lhsDimShapeMap[d];
      } else if (rhsContains) {
        outDimShapeMap[d] = rhsDimShapeMap[d];
      }
    }
  };

  generateOutDimShapeMap(contractingDims);
  generateOutDimShapeMap(batchingDims);
  generateOutDimShapeMap(lhsReduceDims);
  generateOutDimShapeMap(rhsReduceDims);
  generateOutDimShapeMap(lhsOtherDims);
  generateOutDimShapeMap(rhsOtherDims);

  // shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims]
  lhs = permuteTensorForMatmul(rewriter, loc, lhs, lhsTokens, batchingDims,
                               contractingDims, lhsOtherDims, lhsReduceDims,
                               true);
  // shape: [*batchingDims, *rhsContractingDims, *rhsOtherDims, *rhsReduceDims]
  rhs = permuteTensorForMatmul(rewriter, loc, rhs, rhsTokens, batchingDims,
                               contractingDims, rhsOtherDims, rhsReduceDims,
                               false);
  // shape: [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
  lhs = collapseDimForMatmul(rewriter, loc, lhs, batchingDims.size(),
                             contractingDims.size(), lhsOtherDims.size(),
                             lhsReduceDims.size(), true);
  // shape: [batchingDimsProd, rhsContractingDimsProd, rhsOtherDimsProd]
  rhs = collapseDimForMatmul(rewriter, loc, rhs, batchingDims.size(),
                             contractingDims.size(), rhsOtherDims.size(),
                             rhsReduceDims.size(), false);

  // perform matmul
  auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType);
  result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);

  // generate ideal result dims.
  generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
                               lhsReduceDims, rhsReduceDims, resultTokens);

  // reshape matmul result to ideal shape:
  // [batchingDimsProd, lhsOtherDimsProd, rhsOtherDimsProd] =>
  // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
  // *rhsReduceDims]
  SmallVector<Value> outShapeTensors;
  for (char d : resultTokens) {
    outShapeTensors.emplace_back(outDimShapeMap[d]);
  }

  auto outResultShape = rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())),
      outShapeTensors);
  result = rewriter.create<Torch::AtenReshapeOp>(
      loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result,
      outResultShape);
  return success();
}
|
||||||
|
|
||||||
|
|
||||||
|
// Final einsum step. Dimensions of `input` (subscripts `inputTokens`) whose
// subscript is missing from `outTokens` are summed away with
// `aten.sum.dim_IntList` (keepdim = false, dtype = none); the remaining
// dimensions are then permuted into `outTokens` order. The permute is always
// emitted and is given the type `outType`.
static Value performLastReduceAndPermute(PatternRewriter &rewriter,
                                         Location loc, Type outType,
                                         Value input,
                                         SmallVector<char> &inputTokens,
                                         SmallVector<char> &outTokens) {
  auto inputType = input.getType().cast<BaseTensorType>();
  llvm::SmallDenseSet<char> keptTokens(outTokens.begin(), outTokens.end());

  // Split the input dims into those to reduce and those that survive; for the
  // survivors remember their index after the reduction has removed the others.
  SmallVector<int64_t> reducedDims;
  llvm::SmallDenseMap<char, int64_t> indexAfterReduce;
  int64_t nextKeptIndex = 0;
  for (size_t dim = 0; dim < inputTokens.size(); ++dim) {
    char token = inputTokens[dim];
    if (keptTokens.contains(token))
      indexAfterReduce[token] = nextKeptIndex++;
    else
      reducedDims.emplace_back(dim);
  }

  auto intListType =
      Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));

  if (!reducedDims.empty()) {
    SmallVector<Value> reducedDimValues;
    for (int64_t dim : reducedDims) {
      reducedDimValues.emplace_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(dim)));
    }
    auto reducedDimList = rewriter.create<Torch::PrimListConstructOp>(
        loc, intListType, reducedDimValues);
    auto keepDim = rewriter.create<Torch::ConstantBoolOp>(
        loc, rewriter.getBoolAttr(false));
    auto noDtype = rewriter.create<Torch::ConstantNoneOp>(loc);
    auto unrankedType = inputType.getWithSizesAndDtype(
        std::nullopt, inputType.getOptionalDtype());
    input = rewriter.create<Torch::AtenSumDimIntListOp>(
        loc, unrankedType, input, reducedDimList, keepDim, noDtype);
  }

  SmallVector<Value> permutation;
  for (char token : outTokens) {
    permutation.emplace_back(rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(indexAfterReduce[token])));
  }
  auto permutationList = rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
      permutation);
  return rewriter.create<Torch::AtenPermuteOp>(loc, outType, input,
                                               permutationList);
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
||||||
/// number of dimensions across which the max needs to be computed.
|
/// number of dimensions across which the max needs to be computed.
|
||||||
|
@ -628,6 +980,78 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce
|
||||||
|
// operation and permute operation. Currently, this pass doesn't support
|
||||||
|
// Hadamard product. The basic idea is that:
|
||||||
|
// Step 1: split the string equation to input/result tokens and find
|
||||||
|
// batchingDims, contractingDims, otherDims and reduceDims.
|
||||||
|
// Step 2: permute and reshape input tensors suitable
|
||||||
|
// for matmul operations.
|
||||||
|
// Step 3: use AtenMatmulOp to get the result.
|
||||||
|
// Step 4: iteratively execute step 2 & 3 until we get the final result.
|
||||||
|
// Step 5: perform remaining permute and reduce operations.
|
||||||
|
// notice: support static shape only
|
||||||
|
|
||||||
|
class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenEinsumOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
std::string equation;
|
||||||
|
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
|
||||||
|
}
|
||||||
|
SmallVector<char> resultTokens;
|
||||||
|
SmallVector<SmallVector<char>> inputTokens;
|
||||||
|
if (!parseEquation(equation, inputTokens, resultTokens)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "Unexpected character in equations encountered");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> inputTensors;
|
||||||
|
if (!getListConstructElements(op.getTensors(), inputTensors)) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "input should comes from a PrimListConstructOp");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto allTensorHasSizes = [](Value tensor) {
|
||||||
|
auto type = tensor.getType().dyn_cast<BaseTensorType>();
|
||||||
|
if (!type || !type.hasSizes())
|
||||||
|
return false;
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!llvm::all_of(inputTensors, allTensorHasSizes)) {
|
||||||
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
"all input tensors should have sizes");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<char> lhsTokens = inputTokens[0];
|
||||||
|
Value lhs = inputTensors[0];
|
||||||
|
Value result;
|
||||||
|
|
||||||
|
for (size_t i = 1; i < inputTensors.size(); ++i) {
|
||||||
|
auto rhs = inputTensors[i];
|
||||||
|
auto rhsTokens = inputTokens[i];
|
||||||
|
SmallVector<char> outTokens;
|
||||||
|
if (failed(performMatmul(rewriter, loc, lhs, lhsTokens, rhs, rhsTokens,
|
||||||
|
result, outTokens, resultTokens))) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
lhs = result;
|
||||||
|
lhsTokens = outTokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
result = performLastReduceAndPermute(rewriter, loc, op.getType(), lhs,
|
||||||
|
lhsTokens, resultTokens);
|
||||||
|
rewriter.replaceOp(op, result);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
||||||
// exp(x)/sum(exp(x)).
|
// exp(x)/sum(exp(x)).
|
||||||
// To avoid overflow we use the following decomposition rule:
|
// To avoid overflow we use the following decomposition rule:
|
||||||
|
@ -5798,6 +6222,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSiluOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSiluOp>(patterns);
|
||||||
|
|
|
@ -385,6 +385,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenReshapeOp>();
|
target.addIllegalOp<AtenReshapeOp>();
|
||||||
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
||||||
target.addIllegalOp<AtenTanhBackwardOp>();
|
target.addIllegalOp<AtenTanhBackwardOp>();
|
||||||
|
target.addIllegalOp<AtenEinsumOp>();
|
||||||
target.addIllegalOp<AtenAddmmOp>();
|
target.addIllegalOp<AtenAddmmOp>();
|
||||||
target.addIllegalOp<AtenMeanOp>();
|
target.addIllegalOp<AtenMeanOp>();
|
||||||
target.addIllegalOp<AtenMeanDimOp>();
|
target.addIllegalOp<AtenMeanDimOp>();
|
||||||
|
|
|
@ -554,6 +554,9 @@ STABLEHLO_PASS_SET = {
|
||||||
"EmptyLikeModule_int",
|
"EmptyLikeModule_int",
|
||||||
"ExpandAsIntModule_basic",
|
"ExpandAsIntModule_basic",
|
||||||
"ExpandModule_basic",
|
"ExpandModule_basic",
|
||||||
|
"EinsumStaticModule_basic",
|
||||||
|
"EinsumStaticFourDimensionModule_basic",
|
||||||
|
"EinsumStaticContractRhsModule_basic",
|
||||||
"Fill_TensorFloat64WithFloat32_basic",
|
"Fill_TensorFloat64WithFloat32_basic",
|
||||||
"Fill_TensorFloat64WithFloat64_basic",
|
"Fill_TensorFloat64WithFloat64_basic",
|
||||||
"Fill_TensorFloat64WithInt64_basic",
|
"Fill_TensorFloat64WithInt64_basic",
|
||||||
|
@ -1020,6 +1023,9 @@ TOSA_PASS_SET = {
|
||||||
"RsubFloatModule_basic",
|
"RsubFloatModule_basic",
|
||||||
"RsubFloatModule_noalpha_basic",
|
"RsubFloatModule_noalpha_basic",
|
||||||
"RsubInt0d_NumToTensor_Module_basic",
|
"RsubInt0d_NumToTensor_Module_basic",
|
||||||
|
"EinsumStaticModule_basic",
|
||||||
|
"EinsumStaticFourDimensionModule_basic",
|
||||||
|
"EinsumStaticContractRhsModule_basic",
|
||||||
"ElementwiseBitwiseAndModule_basic",
|
"ElementwiseBitwiseAndModule_basic",
|
||||||
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
||||||
"ElementwiseBitwiseNotInt32Module_basic",
|
"ElementwiseBitwiseNotInt32Module_basic",
|
||||||
|
|
|
@ -3684,6 +3684,19 @@ def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0)
|
||||||
dtypes.append(tensor_dtype)
|
dtypes.append(tensor_dtype)
|
||||||
return promote_dtypes(ranks, dtypes)
|
return promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
|
@check_dtype_function(
    [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
                            TensorOfShape(1, dtype=torch.int32)]),])
def aten〇einsum〡dtype(equation: str, tensors_rank_dtype: List[Tuple[int, int]], path: Optional[List[int]] = None) -> int:
    # Result dtype of einsum: the promotion of all operand (rank, dtype) pairs.
    # `equation` and `path` are not consulted.
    operand_ranks: List[Optional[int]] = []
    operand_dtypes: List[int] = []
    # At least one operand is required.
    assert len(tensors_rank_dtype) != 0
    for rank_and_dtype in tensors_rank_dtype:
        operand_rank, operand_dtype = rank_and_dtype
        operand_ranks.append(operand_rank)
        operand_dtypes.append(operand_dtype)
    return promote_dtypes(operand_ranks, operand_dtypes)
|
||||||
|
|
||||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||||
def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
return torch.int64
|
return torch.int64
|
||||||
|
|
|
@ -566,6 +566,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
||||||
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
|
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
|
||||||
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
|
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
|
||||||
|
emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
|
||||||
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
|
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
|
||||||
emit("aten::clone : (Tensor, int?) -> (Tensor)")
|
emit("aten::clone : (Tensor, int?) -> (Tensor)")
|
||||||
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
|
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
|
||||||
|
|
|
@ -1044,3 +1044,59 @@ class UnflattenIntNegativeOneSizeStaticModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: UnflattenIntNegativeOneSizeStaticModule())
|
@register_test_case(module_factory=lambda: UnflattenIntNegativeOneSizeStaticModule())
|
||||||
def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
|
def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(5, 12, 3))
|
module.forward(tu.rand(5, 12, 3))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class EinsumStaticModule(torch.nn.Module):
    # Three-operand einsum 'bqe,ked,btd->bqtk' on float32 inputs annotated
    # with shapes [3, 2, 4], [5, 4, 6] and [3, 7, 6].

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 2, 4], torch.float32, True),
        ([5, 4, 6], torch.float32, True),
        ([3, 7, 6], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2, tensor3):
        operands = [tensor1, tensor2, tensor3]
        return torch.ops.aten.einsum('bqe,ked,btd->bqtk', operands)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EinsumStaticModule())
def EinsumStaticModule_basic(module, tu: TestUtils):
    # Random operands, generated in argument order.
    tensor1 = tu.rand(3, 2, 4)
    tensor2 = tu.rand(5, 4, 6)
    tensor3 = tu.rand(3, 7, 6)
    module.forward(tensor1, tensor2, tensor3)
|
||||||
|
|
||||||
|
|
||||||
|
class EinsumStaticFourDimensionModule(torch.nn.Module):
    # Two-operand einsum 'blhd,bshd->blhs' on rank-4 float32 inputs annotated
    # with shapes [3, 4, 5, 6] and [3, 7, 5, 6].

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5, 6], torch.float32, True),
        ([3, 7, 5, 6], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        operands = [tensor1, tensor2]
        return torch.ops.aten.einsum('blhd,bshd->blhs', operands)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EinsumStaticFourDimensionModule())
def EinsumStaticFourDimensionModule_basic(module, tu: TestUtils):
    # Random operands, generated in argument order.
    tensor1 = tu.rand(3, 4, 5, 6)
    tensor2 = tu.rand(3, 7, 5, 6)
    module.forward(tensor1, tensor2)
|
||||||
|
|
||||||
|
|
||||||
|
class EinsumStaticContractRhsModule(torch.nn.Module):
    # Two-operand einsum 'abc,bc->a': every subscript of the second operand is
    # contracted. Inputs are float32, annotated with shapes [3, 4, 5] and
    # [4, 5].

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
        ([4, 5], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        operands = [tensor1, tensor2]
        return torch.ops.aten.einsum('abc,bc->a', operands)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: EinsumStaticContractRhsModule())
def EinsumStaticContractRhsModule_basic(module, tu: TestUtils):
    # Random operands, generated in argument order.
    tensor1 = tu.rand(3, 4, 5)
    tensor2 = tu.rand(4, 5)
    module.forward(tensor1, tensor2)
|
Loading…
Reference in New Issue