mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Support Einsum Op (#2230)
As the title says, this adds support for the torch.aten.einsum op. Right now only static shapes are supported because of a known issue; the fix for that issue is here: https://github.com/llvm/torch-mlir/pull/2154 Co-authored-by: Jiawei Wu [wujiawei.aml@bytedance.com](mailto:wujiawei.aml@bytedance.com) pull/2628/head snapshot-20231210.1048
parent
07c3e11f56
commit
96fcde4d77
|
@ -8447,6 +8447,31 @@ def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [
|
|||
}];
|
||||
}
|
||||
|
||||
// Op definition for `aten::einsum`. It takes the equation string, a list of
// input tensors and an optional list of ints (`path`), and yields one tensor.
// The custom assembly format is implemented by forwarding to the shared
// parseDefaultTorchOp / printDefaultTorchOp helpers.
// NOTE(review): the `3, 1` arguments presumably are the operand and result
// counts -- confirm against the helper declarations.
def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::einsum : (str, Tensor[], int[]?) -> (Tensor)`";
  let arguments = (ins
    Torch_StringType:$equation,
    AnyTorchListOfTensorType:$tensors,
    AnyTorchOptionalListOfTorchIntType:$path
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenEinsumOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenEinsumOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
|
||||
|
||||
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -11321,6 +11321,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %5 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
|
||||
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
|
||||
" %2 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
|
||||
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %3 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %4 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
|
||||
" torch.prim.Loop %4, %true, init() {\n"
|
||||
" ^bb0(%arg3: !torch.int):\n"
|
||||
" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
|
||||
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
|
||||
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %5 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" return %int4 : !torch.int\n"
|
||||
|
|
|
@ -187,6 +187,358 @@ static SmallVector<int64_t> computeDimsOrderForMoveDim(int64_t srcDimInt,
|
|||
return dimsOrder;
|
||||
}
|
||||
|
||||
static bool parseEquation(const std::string &equation,
|
||||
SmallVector<SmallVector<char>> &inputTokens,
|
||||
SmallVector<char> &resultTokens) {
|
||||
SmallVector<char> inputToken;
|
||||
size_t index = 0;
|
||||
enum EquationVariable { kIsInput, kIsResult };
|
||||
EquationVariable currentVariable = kIsInput;
|
||||
while (index < equation.size()) {
|
||||
if (std::isalpha(equation[index])) {
|
||||
if (currentVariable == kIsInput) {
|
||||
inputToken.push_back(equation[index]);
|
||||
} else {
|
||||
resultTokens.push_back(equation[index]);
|
||||
}
|
||||
} else if (equation[index] == ',') {
|
||||
inputTokens.push_back(inputToken);
|
||||
inputToken.clear();
|
||||
} else if ((index < (equation.size() - 1)) &&
|
||||
(equation.substr(index, 2).find("->") != std::string::npos)) {
|
||||
inputTokens.push_back(inputToken);
|
||||
inputToken.clear();
|
||||
currentVariable = kIsResult;
|
||||
index++;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
index++;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] =>
|
||||
// [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
|
||||
static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
||||
Value input, int64_t batchDimsLength,
|
||||
int64_t contractingDimsLength,
|
||||
int64_t otherDimsLength,
|
||||
int64_t reduceDimsLength, bool isLhs) {
|
||||
auto inputType = input.getType().cast<BaseTensorType>();
|
||||
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
||||
reduceDimsLength;
|
||||
SmallVector<Value> inputShapeTensor;
|
||||
for (auto i = 0; i < inputRank; ++i) {
|
||||
inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
|
||||
loc, input,
|
||||
rewriter.create<Torch::ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(i))));
|
||||
}
|
||||
|
||||
SmallVector<Value> outShapeTensor;
|
||||
Value constOne =
|
||||
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
auto dimOffset = 0;
|
||||
|
||||
auto appendDims = [&](int64_t dimLength) {
|
||||
Value prod = constOne;
|
||||
for (auto i = 0; i < dimLength; ++i) {
|
||||
prod = rewriter.create<AtenMulIntOp>(loc, prod,
|
||||
inputShapeTensor[i + dimOffset]);
|
||||
}
|
||||
outShapeTensor.emplace_back(prod);
|
||||
dimOffset += dimLength;
|
||||
};
|
||||
|
||||
appendDims(batchDimsLength);
|
||||
if (!isLhs)
|
||||
appendDims(contractingDimsLength);
|
||||
appendDims(otherDimsLength + reduceDimsLength);
|
||||
if (isLhs)
|
||||
appendDims(contractingDimsLength);
|
||||
|
||||
auto outShapeValue = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
|
||||
outShapeTensor);
|
||||
|
||||
auto outType = inputType.getWithSizesAndDtype(std::nullopt,
|
||||
inputType.getOptionalDtype());
|
||||
return rewriter.create<Torch::AtenReshapeOp>(loc, outType, input,
|
||||
outShapeValue);
|
||||
}
|
||||
|
||||
// classify every dim token into different categories. Note that although we
|
||||
// parse out reduce dims, we delay their execution until
|
||||
// `performLastPermuteAndReduce`.
|
||||
static void parseDimTokens(
|
||||
SmallVector<char> &lhsTokens, SmallVector<char> &rhsTokens,
|
||||
SmallVector<char> &finalResultTokens, SmallVector<char> &contractingDims,
|
||||
SmallVector<char> &lhsReduceDims, SmallVector<char> &rhsReduceDims,
|
||||
SmallVector<char> &batchingDims, SmallVector<char> &lhsOtherDims,
|
||||
SmallVector<char> &rhsOtherDims) {
|
||||
llvm::SmallDenseSet<char> lhsTokenSet(lhsTokens.begin(), lhsTokens.end());
|
||||
llvm::SmallDenseSet<char> rhsTokenSet(rhsTokens.begin(), rhsTokens.end());
|
||||
llvm::SmallDenseSet<char> finalResultTokenSet(finalResultTokens.begin(),
|
||||
finalResultTokens.end());
|
||||
|
||||
for (size_t i = 0; i < lhsTokens.size(); ++i) {
|
||||
bool rhsContains = rhsTokenSet.contains(lhsTokens[i]);
|
||||
bool finalResultConatins = finalResultTokenSet.contains(lhsTokens[i]);
|
||||
// batching dim
|
||||
if (rhsContains && finalResultConatins) {
|
||||
batchingDims.push_back(lhsTokens[i]);
|
||||
// reduce dim of lhs
|
||||
} else if (!rhsContains && !finalResultConatins) {
|
||||
lhsReduceDims.push_back(lhsTokens[i]);
|
||||
// other dim of lhs
|
||||
} else if (finalResultConatins) {
|
||||
lhsOtherDims.push_back(lhsTokens[i]);
|
||||
// contracting dim of lhs
|
||||
} else if (rhsContains) {
|
||||
contractingDims.push_back(lhsTokens[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < rhsTokens.size(); ++i) {
|
||||
bool lhsContains = lhsTokenSet.contains(rhsTokens[i]);
|
||||
bool finalResultConatins = finalResultTokenSet.contains(rhsTokens[i]);
|
||||
// batching dim
|
||||
if (lhsContains && finalResultConatins) {
|
||||
// reduce dim of rhs
|
||||
} else if (!lhsContains && !finalResultConatins) {
|
||||
rhsReduceDims.push_back(rhsTokens[i]);
|
||||
// other dim of rhs
|
||||
} else if (finalResultConatins) {
|
||||
rhsOtherDims.push_back(rhsTokens[i]);
|
||||
// contracting dim of rhs
|
||||
} else if (lhsContains) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Appends to `resultTokens` the subscript order of the un-collapsed matmul
// result:
//   [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
//    *rhsReduceDims]
static void generateIdealReusltDimTokens(SmallVector<char> &batchingDims,
                                         SmallVector<char> &lhsOtherDims,
                                         SmallVector<char> &rhsOtherDims,
                                         SmallVector<char> &lhsReduceDims,
                                         SmallVector<char> &rhsReduceDims,
                                         SmallVector<char> &resultTokens) {
  for (SmallVector<char> *group : {&batchingDims, &lhsOtherDims, &lhsReduceDims,
                                   &rhsOtherDims, &rhsReduceDims}) {
    resultTokens.append(group->begin(), group->end());
  }
}
|
||||
|
||||
// Emits an aten.permute that reorders `input` so that its dims are grouped for
// the following collapse + matmul:
//   lhs (isLhs == true):  [*batchingDims, *otherDims, *reduceDims,
//                          *contractingDims]
//   rhs (isLhs == false): [*batchingDims, *contractingDims, *otherDims,
//                          *reduceDims]
// `dimTokens` gives the subscript of every dim of `input` in its current
// order; each token's position in that list is used as its source dim index.
// The result type carries no static sizes.
static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
                                    Value input, SmallVector<char> &dimTokens,
                                    SmallVector<char> &batchingDims,
                                    SmallVector<char> &contractingDims,
                                    SmallVector<char> &otherDims,
                                    SmallVector<char> &reduceDims, bool isLhs) {
  auto inputType = input.getType().cast<BaseTensorType>();
  llvm::SmallDenseMap<char, int64_t> dimTokenMap;
  for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
    dimTokenMap[dimTokens[idx]] = idx;
  }

  SmallVector<Value> permuteVec;
  // Takes the group by const reference: no per-call copy, and the parameter
  // name no longer hides the function's own `dimTokens`.
  auto appendDims = [&](const SmallVector<char> &group) {
    for (char d : group) {
      permuteVec.push_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(dimTokenMap[d])));
    }
  };

  appendDims(batchingDims);
  if (!isLhs)
    appendDims(contractingDims);
  appendDims(otherDims);
  appendDims(reduceDims);
  if (isLhs)
    appendDims(contractingDims);

  Value dstDims = rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
      permuteVec);
  auto outType = inputType.getWithSizesAndDtype(std::nullopt,
                                                inputType.getOptionalDtype());
  return rewriter.create<Torch::AtenPermuteOp>(loc, outType, input, dstDims);
}
|
||||
|
||||
// Contracts `lhs` with `rhs` through a single aten.matmul.
//
// `lhsTokens` / `rhsTokens` give the subscript of every dim of the operands
// and `finalResultTokens` the subscripts of the overall einsum result. On
// success `result` holds a tensor whose dims are
//   [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
//    *rhsReduceDims]
// and those subscripts are appended to `resultTokens`. Reduce dims are kept
// here and summed later by the caller.
//
// The result dtype is the lhs dtype when known, otherwise the rhs dtype.
//
// Fails, without creating any op, when there is no contracting dim and no
// "other" dim on either side (a pure Hadamard product).
static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
                                   Value lhs, SmallVector<char> &lhsTokens,
                                   Value rhs, SmallVector<char> &rhsTokens,
                                   Value &result,
                                   SmallVector<char> &resultTokens,
                                   SmallVector<char> &finalResultTokens) {
  auto lhsType = lhs.getType().cast<BaseTensorType>();
  auto rhsType = rhs.getType().cast<BaseTensorType>();

  // parse batch, contracting, other, reduce dims of lhs and rhs
  SmallVector<char> contractingDims;
  SmallVector<char> lhsReduceDims;
  SmallVector<char> rhsReduceDims;
  SmallVector<char> lhsOtherDims;
  SmallVector<char> rhsOtherDims;
  SmallVector<char> batchingDims;
  parseDimTokens(lhsTokens, rhsTokens, finalResultTokens, contractingDims,
                 lhsReduceDims, rhsReduceDims, batchingDims, lhsOtherDims,
                 rhsOtherDims);

  // Reject unsupported equations before any IR is created: a pattern that
  // reports failure is expected to leave the IR untouched.
  if (contractingDims.empty() && lhsOtherDims.empty() &&
      rhsOtherDims.empty()) {
    return rewriter.notifyMatchFailure(
        loc, "Hadamard product is currently not supported");
  }

  Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
                                        : rhsType.getOptionalDtype();

  // Size of every lhs / rhs dim, keyed by its subscript.
  llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
  for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
    char d = lhsTokens[idx];
    lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
        loc, lhs,
        rewriter.create<Torch::ConstantIntOp>(loc,
                                              rewriter.getI64IntegerAttr(idx)));
  }
  llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
  for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
    char d = rhsTokens[idx];
    rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
        loc, rhs,
        rewriter.create<Torch::ConstantIntOp>(loc,
                                              rewriter.getI64IntegerAttr(idx)));
  }

  // Size of every subscript in the result. A subscript present on both sides
  // takes the max of the two sizes.
  llvm::SmallDenseMap<char, Value> outDimShapeMap;
  auto generateOutDimShapeMap = [&](SmallVector<char> &dims) {
    for (auto d : dims) {
      bool lhsContains = lhsDimShapeMap.count(d) > 0;
      bool rhsContains = rhsDimShapeMap.count(d) > 0;
      if (lhsContains && rhsContains) {
        outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
            loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
      } else if (lhsContains) {
        outDimShapeMap[d] = lhsDimShapeMap[d];
      } else if (rhsContains) {
        outDimShapeMap[d] = rhsDimShapeMap[d];
      }
    }
  };

  generateOutDimShapeMap(contractingDims);
  generateOutDimShapeMap(batchingDims);
  generateOutDimShapeMap(lhsReduceDims);
  generateOutDimShapeMap(rhsReduceDims);
  generateOutDimShapeMap(lhsOtherDims);
  generateOutDimShapeMap(rhsOtherDims);

  // shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims]
  lhs = permuteTensorForMatmul(rewriter, loc, lhs, lhsTokens, batchingDims,
                               contractingDims, lhsOtherDims, lhsReduceDims,
                               true);
  // shape: [*batchingDims, *rhsContractingDims, *rhsOtherDims, *rhsReduceDims]
  rhs = permuteTensorForMatmul(rewriter, loc, rhs, rhsTokens, batchingDims,
                               contractingDims, rhsOtherDims, rhsReduceDims,
                               false);
  // shape: [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
  // (lhsOtherDimsProd also covers the lhs reduce dims)
  lhs = collapseDimForMatmul(rewriter, loc, lhs, batchingDims.size(),
                             contractingDims.size(), lhsOtherDims.size(),
                             lhsReduceDims.size(), true);
  // shape: [batchingDimsProd, rhsContractingDimsProd, rhsOtherDimsProd]
  // (rhsOtherDimsProd also covers the rhs reduce dims)
  rhs = collapseDimForMatmul(rewriter, loc, rhs, batchingDims.size(),
                             contractingDims.size(), rhsOtherDims.size(),
                             rhsReduceDims.size(), false);

  // perform matmul
  auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType);
  result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);

  // generate ideal result dims.
  generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
                               lhsReduceDims, rhsReduceDims, resultTokens);

  // reshape matmul result to ideal shape:
  // [batchingDimsProd, lhsOtherDimsProd, rhsOtherDimsProd] =>
  // [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
  // *rhsReduceDims]
  SmallVector<Value> outShapeTensors;
  for (char d : resultTokens) {
    outShapeTensors.emplace_back(outDimShapeMap[d]);
  }

  auto outResultShape = rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())),
      outShapeTensors);
  result = rewriter.create<Torch::AtenReshapeOp>(
      loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result,
      outResultShape);
  return success();
}
|
||||
|
||||
|
||||
// Finishes the einsum lowering for `input`, whose dims are labelled by
// `inputTokens`:
//   1. every dim whose label is absent from `outTokens` is summed away with
//      aten.sum.dim_IntList (keepdim = false, dtype = none);
//   2. the surviving dims are permuted into the order given by `outTokens`.
// The permute op is created with result type `outType` and returned.
static Value performLastReduceAndPermute(PatternRewriter &rewriter,
                                         Location loc, Type outType,
                                         Value input,
                                         SmallVector<char> &inputTokens,
                                         SmallVector<char> &outTokens) {
  auto tensorType = input.getType().cast<BaseTensorType>();

  llvm::SmallDenseSet<char> keptLabels(outTokens.begin(), outTokens.end());
  // Axes of `input` to sum over, and, for every surviving label, its axis
  // once the summed axes are gone.
  SmallVector<int64_t> reducedAxes;
  llvm::SmallDenseMap<char, int64_t> axisAfterReduce;
  int64_t nextKeptAxis = 0;
  for (size_t axis = 0, e = inputTokens.size(); axis != e; ++axis) {
    const char label = inputTokens[axis];
    if (keptLabels.contains(label)) {
      axisAfterReduce[label] = nextKeptAxis++;
      continue;
    }
    reducedAxes.push_back(static_cast<int64_t>(axis));
  }

  if (!reducedAxes.empty()) {
    SmallVector<Value> axisValues;
    for (int64_t axis : reducedAxes) {
      Value axisConst = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(axis));
      axisValues.push_back(axisConst);
    }
    auto axisList = rewriter.create<Torch::PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
        axisValues);
    auto keepDim = rewriter.create<Torch::ConstantBoolOp>(
        loc, rewriter.getBoolAttr(false));
    auto noDtype = rewriter.create<Torch::ConstantNoneOp>(loc);
    auto reducedType = tensorType.getWithSizesAndDtype(
        std::nullopt, tensorType.getOptionalDtype());
    input = rewriter.create<Torch::AtenSumDimIntListOp>(
        loc, reducedType, input, axisList, keepDim, noDtype);
  }

  SmallVector<Value> permutation;
  for (char label : outTokens) {
    Value srcAxis = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(axisAfterReduce[label]));
    permutation.push_back(srcAxis);
  }
  auto permutationList = rewriter.create<Torch::PrimListConstructOp>(
      loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
      permutation);
  Value permuted = rewriter.create<Torch::AtenPermuteOp>(loc, outType, input,
                                                         permutationList);
  return permuted;
}
|
||||
|
||||
namespace {
|
||||
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
|
||||
/// number of dimensions across which the max needs to be computed.
|
||||
|
@ -628,6 +980,78 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Decompose AtenEinsumOp to AtenMatmulOp, and supports possible reduce
|
||||
// operation and permute operation. Currently, this pass doesn't support
|
||||
// Hadamard product. The basic idea is that:
|
||||
// Step 1: split the string equation to input/result tokens and find
|
||||
// batchingDims, contractingDims, otherDims and reduceDims.
|
||||
// Step 2: permute and reshape input tensors suitable
|
||||
// for matmul operations.
|
||||
// Step 3: use AtenMatmulOp to get the result.
|
||||
// Step 4: iteratively execute step 2 & 3 until we get the final result.
|
||||
// Step 5: perform remaining permute and reduce operations.
|
||||
// notice: support static shape only
|
||||
|
||||
class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenEinsumOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
std::string equation;
|
||||
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
|
||||
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
|
||||
}
|
||||
SmallVector<char> resultTokens;
|
||||
SmallVector<SmallVector<char>> inputTokens;
|
||||
if (!parseEquation(equation, inputTokens, resultTokens)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unexpected character in equations encountered");
|
||||
}
|
||||
|
||||
SmallVector<Value> inputTensors;
|
||||
if (!getListConstructElements(op.getTensors(), inputTensors)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input should comes from a PrimListConstructOp");
|
||||
}
|
||||
|
||||
auto allTensorHasSizes = [](Value tensor) {
|
||||
auto type = tensor.getType().dyn_cast<BaseTensorType>();
|
||||
if (!type || !type.hasSizes())
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
if (!llvm::all_of(inputTensors, allTensorHasSizes)) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"all input tensors should have sizes");
|
||||
}
|
||||
|
||||
SmallVector<char> lhsTokens = inputTokens[0];
|
||||
Value lhs = inputTensors[0];
|
||||
Value result;
|
||||
|
||||
for (size_t i = 1; i < inputTensors.size(); ++i) {
|
||||
auto rhs = inputTensors[i];
|
||||
auto rhsTokens = inputTokens[i];
|
||||
SmallVector<char> outTokens;
|
||||
if (failed(performMatmul(rewriter, loc, lhs, lhsTokens, rhs, rhsTokens,
|
||||
result, outTokens, resultTokens))) {
|
||||
return failure();
|
||||
}
|
||||
lhs = result;
|
||||
lhsTokens = outTokens;
|
||||
}
|
||||
|
||||
result = performLastReduceAndPermute(rewriter, loc, op.getType(), lhs,
|
||||
lhsTokens, resultTokens);
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
|
||||
// exp(x)/sum(exp(x)).
|
||||
// To avoid overflow we use the following decomposition rule:
|
||||
|
@ -5798,6 +6222,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSiluOp>(patterns);
|
||||
|
|
|
@ -385,6 +385,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenReshapeOp>();
|
||||
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
|
||||
target.addIllegalOp<AtenTanhBackwardOp>();
|
||||
target.addIllegalOp<AtenEinsumOp>();
|
||||
target.addIllegalOp<AtenAddmmOp>();
|
||||
target.addIllegalOp<AtenMeanOp>();
|
||||
target.addIllegalOp<AtenMeanDimOp>();
|
||||
|
|
|
@ -554,6 +554,9 @@ STABLEHLO_PASS_SET = {
|
|||
"EmptyLikeModule_int",
|
||||
"ExpandAsIntModule_basic",
|
||||
"ExpandModule_basic",
|
||||
"EinsumStaticModule_basic",
|
||||
"EinsumStaticFourDimensionModule_basic",
|
||||
"EinsumStaticContractRhsModule_basic",
|
||||
"Fill_TensorFloat64WithFloat32_basic",
|
||||
"Fill_TensorFloat64WithFloat64_basic",
|
||||
"Fill_TensorFloat64WithInt64_basic",
|
||||
|
@ -1020,6 +1023,9 @@ TOSA_PASS_SET = {
|
|||
"RsubFloatModule_basic",
|
||||
"RsubFloatModule_noalpha_basic",
|
||||
"RsubInt0d_NumToTensor_Module_basic",
|
||||
"EinsumStaticModule_basic",
|
||||
"EinsumStaticFourDimensionModule_basic",
|
||||
"EinsumStaticContractRhsModule_basic",
|
||||
"ElementwiseBitwiseAndModule_basic",
|
||||
"ElementwiseBitwiseAndStaticShapeModule_basic",
|
||||
"ElementwiseBitwiseNotInt32Module_basic",
|
||||
|
|
|
@ -3684,6 +3684,19 @@ def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0)
|
|||
dtypes.append(tensor_dtype)
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function(
    [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
                            TensorOfShape(1, dtype=torch.int32)]),])
def aten〇einsum〡dtype(equation: str, tensors_rank_dtype: List[Tuple[int, int]], path: Optional[List[int]] = None) -> int:
    # Dtype function for aten.einsum: the result dtype is obtained by passing
    # the rank and dtype of every input tensor to `promote_dtypes`.
    # `equation` and `path` are not used by this function.
    ranks: List[Optional[int]] = []
    dtypes: List[int] = []
    # At least one input tensor is required.
    assert len(tensors_rank_dtype) != 0
    for tensor_rank_dtype in tensors_rank_dtype:
        # Each entry is a (rank, dtype) pair.
        tensor_rank, tensor_dtype = tensor_rank_dtype
        ranks.append(tensor_rank)
        dtypes.append(tensor_dtype)
    return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
|
||||
def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
return torch.int64
|
||||
|
|
|
@ -566,6 +566,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
|
||||
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
|
||||
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
|
||||
emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
|
||||
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
|
||||
emit("aten::clone : (Tensor, int?) -> (Tensor)")
|
||||
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")
|
||||
|
|
|
@ -1044,3 +1044,59 @@ class UnflattenIntNegativeOneSizeStaticModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: UnflattenIntNegativeOneSizeStaticModule())
|
||||
def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 12, 3))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class EinsumStaticModule(torch.nn.Module):
    """Three-operand einsum 'bqe,ked,btd->bqtk' on float32 inputs annotated
    with the static shapes [3, 2, 4], [5, 4, 6] and [3, 7, 6]."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 2, 4], torch.float32, True),
        ([5, 4, 6], torch.float32, True),
        ([3, 7, 6], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2, tensor3):
        return torch.ops.aten.einsum('bqe,ked,btd->bqtk', [tensor1, tensor2, tensor3])
|
||||
|
||||
@register_test_case(module_factory=lambda: EinsumStaticModule())
def EinsumStaticModule_basic(module, tu: TestUtils):
    # Input shapes match the static shapes annotated on the module's forward().
    module.forward(tu.rand(3, 2, 4), tu.rand(5, 4, 6), tu.rand(3, 7, 6))
|
||||
|
||||
|
||||
class EinsumStaticFourDimensionModule(torch.nn.Module):
    """Two-operand einsum 'blhd,bshd->blhs' on rank-4 float32 inputs annotated
    with the static shapes [3, 4, 5, 6] and [3, 7, 5, 6]."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5, 6], torch.float32, True),
        ([3, 7, 5, 6], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        return torch.ops.aten.einsum('blhd,bshd->blhs', [tensor1, tensor2])
|
||||
|
||||
@register_test_case(module_factory=lambda: EinsumStaticFourDimensionModule())
def EinsumStaticFourDimensionModule_basic(module, tu: TestUtils):
    # Input shapes match the static shapes annotated on the module's forward().
    module.forward(tu.rand(3, 4, 5, 6), tu.rand(3, 7, 5, 6))
|
||||
|
||||
|
||||
class EinsumStaticContractRhsModule(torch.nn.Module):
    """Two-operand einsum 'abc,bc->a', where every subscript of the second
    operand is contracted away. Inputs are float32 with the static shapes
    [3, 4, 5] and [4, 5]."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
        ([4, 5], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        return torch.ops.aten.einsum('abc,bc->a', [tensor1, tensor2])
|
||||
|
||||
@register_test_case(module_factory=lambda: EinsumStaticContractRhsModule())
def EinsumStaticContractRhsModule_basic(module, tu: TestUtils):
    # Input shapes match the static shapes annotated on the module's forward().
    module.forward(tu.rand(3, 4, 5), tu.rand(4, 5))
|
Loading…
Reference in New Issue