[Torch Dialect] Support Einsum Op (#2230)

As the title says, this adds support for the torch.aten.einsum op.

Only static shapes are supported for now because of a known issue; the fix is
here: https://github.com/llvm/torch-mlir/pull/2154
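For reference, a quick eager-mode sanity check of what the op computes (a minimal sketch; the equation and shapes are borrowed from the EinsumStaticContractRhsModule e2e test added below):

```python
import torch

t1 = torch.rand(3, 4, 5)
t2 = torch.rand(4, 5)
# 'abc,bc->a': b and c are shared between the operands and absent from the
# result, so they are contracted; only a survives.
out = torch.ops.aten.einsum('abc,bc->a', [t1, t2])
assert out.shape == (3,)
assert torch.allclose(out, (t1 * t2).sum(dim=(1, 2)))
```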

Co-authored-by: Jiawei Wu <wujiawei.aml@bytedance.com>
JianzheXiao 2023-12-09 20:30:37 -08:00 committed by GitHub
parent 07c3e11f56
commit 96fcde4d77
8 changed files with 554 additions and 0 deletions


@@ -8447,6 +8447,31 @@ def Torch_AtenOneHotOp : Torch_Op<"aten.one_hot", [
}];
}
def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::einsum : (str, Tensor[], int[]?) -> (Tensor)`";
let arguments = (ins
Torch_StringType:$equation,
AnyTorchListOfTensorType:$tensors,
AnyTorchOptionalListOfTorchIntType:$path
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenEinsumOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenEinsumOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -11321,6 +11321,33 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list<tuple<int, int>>, %arg2: !torch.optional<list<int>>) -> !torch.int {\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<optional<int>>\n"
" %1 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %2 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.len.t %arg1 : !torch.list<tuple<int, int>> -> !torch.int\n"
" torch.prim.Loop %4, %true, init() {\n"
" ^bb0(%arg3: !torch.int):\n"
" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<tuple<int, int>>, !torch.int -> !torch.tuple<int, int>\n"
" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %8 = torch.aten.append.t %0, %7#0 : !torch.list<optional<int>>, !torch.int -> !torch.list<optional<int>>\n"
" %9 = torch.aten.append.t %1, %7#1 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int4 = torch.constant.int 4\n" " %int4 = torch.constant.int 4\n"
" return %int4 : !torch.int\n" " return %int4 : !torch.int\n"


@@ -187,6 +187,358 @@ static SmallVector<int64_t> computeDimsOrderForMoveDim(int64_t srcDimInt,
return dimsOrder;
}
static bool parseEquation(const std::string &equation,
SmallVector<SmallVector<char>> &inputTokens,
SmallVector<char> &resultTokens) {
SmallVector<char> inputToken;
size_t index = 0;
enum EquationVariable { kIsInput, kIsResult };
EquationVariable currentVariable = kIsInput;
while (index < equation.size()) {
if (std::isalpha(equation[index])) {
if (currentVariable == kIsInput) {
inputToken.push_back(equation[index]);
} else {
resultTokens.push_back(equation[index]);
}
} else if (equation[index] == ',') {
inputTokens.push_back(inputToken);
inputToken.clear();
} else if ((index < (equation.size() - 1)) &&
(equation.substr(index, 2).find("->") != std::string::npos)) {
inputTokens.push_back(inputToken);
inputToken.clear();
currentVariable = kIsResult;
index++;
} else {
return false;
}
index++;
}
return true;
}
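For readers unfamiliar with the tokenizer above, here is a rough Python transliteration (a hypothetical parse_equation helper, for illustration only; the C++ version reports failure by returning false rather than raising):

```python
def parse_equation(equation: str):
    input_tokens, current, result_tokens = [], [], []
    in_result = False
    i = 0
    while i < len(equation):
        c = equation[i]
        if c.isalpha():
            (result_tokens if in_result else current).append(c)
        elif c == ',':
            input_tokens.append(current)
            current = []
        elif equation[i:i + 2] == '->':
            input_tokens.append(current)
            current = []
            in_result = True
            i += 1  # consume the second character of '->'
        else:
            raise ValueError(f"unexpected character {c!r}")
        i += 1
    return input_tokens, result_tokens

assert parse_equation('ij,jk->ik') == ([['i', 'j'], ['j', 'k']], ['i', 'k'])
```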
// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims] =>
// [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
Value input, int64_t batchDimsLength,
int64_t contractingDimsLength,
int64_t otherDimsLength,
int64_t reduceDimsLength, bool isLhs) {
auto inputType = input.getType().cast<BaseTensorType>();
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
reduceDimsLength;
SmallVector<Value> inputShapeTensor;
for (auto i = 0; i < inputRank; ++i) {
inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
loc, input,
rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(i))));
}
SmallVector<Value> outShapeTensor;
Value constOne =
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto dimOffset = 0;
auto appendDims = [&](int64_t dimLength) {
Value prod = constOne;
for (auto i = 0; i < dimLength; ++i) {
prod = rewriter.create<AtenMulIntOp>(loc, prod,
inputShapeTensor[i + dimOffset]);
}
outShapeTensor.emplace_back(prod);
dimOffset += dimLength;
};
appendDims(batchDimsLength);
if (!isLhs)
appendDims(contractingDimsLength);
appendDims(otherDimsLength + reduceDimsLength);
if (isLhs)
appendDims(contractingDimsLength);
auto outShapeValue = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
outShapeTensor);
auto outType = inputType.getWithSizesAndDtype(std::nullopt,
inputType.getOptionalDtype());
return rewriter.create<Torch::AtenReshapeOp>(loc, outType, input,
outShapeValue);
}
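The effect of this collapse on the LHS, written as plain tensor ops (the dim sizes here are made up for illustration):

```python
import torch

batch, other, reduce, contract = (2, 3), (4,), (5,), (6, 7)
# Layout after the LHS permute: [*batch, *other, *reduce, *contract].
lhs = torch.rand(*(batch + other + reduce + contract))

def prod(dims):
    p = 1
    for d in dims:
        p *= d
    return p

# -> [batchProd, otherProd * reduceProd, contractProd], ready for matmul.
collapsed = lhs.reshape(prod(batch), prod(other) * prod(reduce), prod(contract))
assert collapsed.shape == (6, 20, 42)
```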
// Classify every dim token into different categories. Note that although we
// parse out reduce dims here, we delay their execution until
// `performLastReduceAndPermute`.
static void parseDimTokens(
SmallVector<char> &lhsTokens, SmallVector<char> &rhsTokens,
SmallVector<char> &finalResultTokens, SmallVector<char> &contractingDims,
SmallVector<char> &lhsReduceDims, SmallVector<char> &rhsReduceDims,
SmallVector<char> &batchingDims, SmallVector<char> &lhsOtherDims,
SmallVector<char> &rhsOtherDims) {
llvm::SmallDenseSet<char> lhsTokenSet(lhsTokens.begin(), lhsTokens.end());
llvm::SmallDenseSet<char> rhsTokenSet(rhsTokens.begin(), rhsTokens.end());
llvm::SmallDenseSet<char> finalResultTokenSet(finalResultTokens.begin(),
finalResultTokens.end());
for (size_t i = 0; i < lhsTokens.size(); ++i) {
bool rhsContains = rhsTokenSet.contains(lhsTokens[i]);
bool finalResultContains = finalResultTokenSet.contains(lhsTokens[i]);
// batching dim
if (rhsContains && finalResultContains) {
batchingDims.push_back(lhsTokens[i]);
// reduce dim of lhs
} else if (!rhsContains && !finalResultContains) {
lhsReduceDims.push_back(lhsTokens[i]);
// other dim of lhs
} else if (finalResultContains) {
lhsOtherDims.push_back(lhsTokens[i]);
// contracting dim of lhs
} else if (rhsContains) {
contractingDims.push_back(lhsTokens[i]);
}
}
for (size_t i = 0; i < rhsTokens.size(); ++i) {
bool lhsContains = lhsTokenSet.contains(rhsTokens[i]);
bool finalResultContains = finalResultTokenSet.contains(rhsTokens[i]);
// batching dim
if (lhsContains && finalResultContains) {
// reduce dim of rhs
} else if (!lhsContains && !finalResultContains) {
rhsReduceDims.push_back(rhsTokens[i]);
// other dim of rhs
} else if (finalResultContains) {
rhsOtherDims.push_back(rhsTokens[i]);
// contracting dim of rhs
} else if (lhsContains) {
}
}
}
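A worked classification for the test equation 'abc,bc->a', from the LHS point of view (illustrative Python, with set membership standing in for the token sets above):

```python
lhs, rhs, result = 'abc', set('bc'), set('a')
batching    = [d for d in lhs if d in rhs and d in result]          # []
contracting = [d for d in lhs if d in rhs and d not in result]      # ['b', 'c']
lhs_other   = [d for d in lhs if d not in rhs and d in result]      # ['a']
lhs_reduce  = [d for d in lhs if d not in rhs and d not in result]  # []
assert (batching, contracting, lhs_other, lhs_reduce) == ([], ['b', 'c'], ['a'], [])
```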
static void generateIdealResultDimTokens(SmallVector<char> &batchingDims,
SmallVector<char> &lhsOtherDims,
SmallVector<char> &rhsOtherDims,
SmallVector<char> &lhsReduceDims,
SmallVector<char> &rhsReduceDims,
SmallVector<char> &resultTokens) {
// generate ideal result dims, i.e.,
// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
// *rhsReduceDims]
resultTokens.insert(resultTokens.end(), batchingDims.begin(),
batchingDims.end());
resultTokens.insert(resultTokens.end(), lhsOtherDims.begin(),
lhsOtherDims.end());
resultTokens.insert(resultTokens.end(), lhsReduceDims.begin(),
lhsReduceDims.end());
resultTokens.insert(resultTokens.end(), rhsOtherDims.begin(),
rhsOtherDims.end());
resultTokens.insert(resultTokens.end(), rhsReduceDims.begin(),
rhsReduceDims.end());
}
static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
Value input, SmallVector<char> &dimTokens,
SmallVector<char> &batchingDims,
SmallVector<char> &contractingDims,
SmallVector<char> &otherDims,
SmallVector<char> &reduceDims, bool isLhs) {
auto inputType = input.getType().cast<BaseTensorType>();
llvm::SmallDenseMap<char, int64_t> dimTokenMap;
for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
dimTokenMap[dimTokens[idx]] = idx;
}
SmallVector<Value> permuteVec;
auto appendDims = [&](SmallVector<char> dimTokens) {
for (auto d : dimTokens) {
permuteVec.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dimTokenMap[d])));
}
};
appendDims(batchingDims);
if (!isLhs)
appendDims(contractingDims);
appendDims(otherDims);
appendDims(reduceDims);
if (isLhs)
appendDims(contractingDims);
Value dstDims = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
permuteVec);
auto outType = inputType.getWithSizesAndDtype(std::nullopt,
inputType.getOptionalDtype());
return rewriter.create<Torch::AtenPermuteOp>(loc, outType, input, dstDims);
}
static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
Value lhs, SmallVector<char> &lhsTokens,
Value rhs, SmallVector<char> &rhsTokens,
Value &result,
SmallVector<char> &resultTokens,
SmallVector<char> &finalResultTokens) {
auto lhsType = lhs.getType().cast<BaseTensorType>();
auto rhsType = rhs.getType().cast<BaseTensorType>();
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
: rhsType.getOptionalDtype();
llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
char d = lhsTokens[idx];
lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
loc, lhs,
rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(idx)));
}
llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
char d = rhsTokens[idx];
rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
loc, rhs,
rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(idx)));
}
// parse batch, contracting, other, reduce dims of lhs and rhs
SmallVector<char> contractingDims;
SmallVector<char> lhsReduceDims;
SmallVector<char> rhsReduceDims;
SmallVector<char> lhsOtherDims;
SmallVector<char> rhsOtherDims;
SmallVector<char> batchingDims;
parseDimTokens(lhsTokens, rhsTokens, finalResultTokens, contractingDims,
lhsReduceDims, rhsReduceDims, batchingDims, lhsOtherDims,
rhsOtherDims);
llvm::SmallDenseMap<char, Value> outDimShapeMap;
auto generateOutDimShapeMap = [&](SmallVector<char> &dims) {
for (auto d : dims) {
bool lhsContains = lhsDimShapeMap.count(d) > 0;
bool rhsContains = rhsDimShapeMap.count(d) > 0;
if (lhsContains && rhsContains) {
outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
} else if (lhsContains) {
outDimShapeMap[d] = lhsDimShapeMap[d];
} else if (rhsContains) {
outDimShapeMap[d] = rhsDimShapeMap[d];
}
}
};
generateOutDimShapeMap(contractingDims);
generateOutDimShapeMap(batchingDims);
generateOutDimShapeMap(lhsReduceDims);
generateOutDimShapeMap(rhsReduceDims);
generateOutDimShapeMap(lhsOtherDims);
generateOutDimShapeMap(rhsOtherDims);
if (contractingDims.size() == 0 && lhsOtherDims.size() == 0 &&
rhsOtherDims.size() == 0) {
return rewriter.notifyMatchFailure(
loc, "Hadamard product is currently not supported");
}
// shape: [*batchingDims, *lhsOtherDims, *lhsReduceDims, *lhsContractingDims]
lhs = permuteTensorForMatmul(rewriter, loc, lhs, lhsTokens, batchingDims,
contractingDims, lhsOtherDims, lhsReduceDims,
true);
// shape: [*batchingDims, *rhsContractingDims, *rhsOtherDims, *rhsReduceDims]
rhs = permuteTensorForMatmul(rewriter, loc, rhs, rhsTokens, batchingDims,
contractingDims, rhsOtherDims, rhsReduceDims,
false);
// shape: [batchingDimsProd, lhsOtherDimsProd, lhsContractingDimsProd]
lhs = collapseDimForMatmul(rewriter, loc, lhs, batchingDims.size(),
contractingDims.size(), lhsOtherDims.size(),
lhsReduceDims.size(), true);
// shape: [batchingDimsProd, rhsContractingDimsProd, rhsOtherDimsProd]
rhs = collapseDimForMatmul(rewriter, loc, rhs, batchingDims.size(),
contractingDims.size(), rhsOtherDims.size(),
rhsReduceDims.size(), false);
// perform matmul
auto outType = lhsType.getWithSizesAndDtype(std::nullopt, outputDType);
result = rewriter.create<Torch::AtenMatmulOp>(loc, outType, lhs, rhs);
// generate ideal result dims.
generateIdealResultDimTokens(batchingDims, lhsOtherDims, rhsOtherDims,
lhsReduceDims, rhsReduceDims, resultTokens);
// reshape matmul result to ideal shape:
// [batchingDimsProd, lhsOtherDimsProd, rhsOtherDimsProd] =>
// [*batchingDims, *lhsOtherDims, *lhsReduceDims, *rhsOtherDims,
// *rhsReduceDims]
SmallVector<Value> outShapeTensors;
for (char d : resultTokens) {
outShapeTensors.emplace_back(outDimShapeMap[d]);
}
auto outResultShape = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())),
outShapeTensors);
result = rewriter.create<Torch::AtenReshapeOp>(
loc, lhsType.getWithSizesAndDtype(std::nullopt, outputDType), result,
outResultShape);
return success();
}
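A shape walk-through of a single performMatmul step for 'blhd,bshd->blhs' (the EinsumStaticFourDimensionModule equation), written with plain torch ops under the same permute/collapse layout as above:

```python
import torch

lhs = torch.rand(3, 4, 5, 6)  # tokens b l h d
rhs = torch.rand(3, 7, 5, 6)  # tokens b s h d
# batching = [b, h], contracting = [d], lhs other = [l], rhs other = [s].
l = lhs.permute(0, 2, 1, 3).reshape(3 * 5, 4, 6)  # [b,h,l,d] -> [bh, l, d]
r = rhs.permute(0, 2, 3, 1).reshape(3 * 5, 6, 7)  # [b,h,d,s] -> [bh, d, s]
out = torch.matmul(l, r).reshape(3, 5, 4, 7)      # ideal dims [b, h, l, s]
assert torch.allclose(out, torch.einsum('blhd,bshd->bhls', lhs, rhs), atol=1e-5)
```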
static Value performLastReduceAndPermute(PatternRewriter &rewriter,
Location loc, Type outType,
Value input,
SmallVector<char> &inputTokens,
SmallVector<char> &outTokens) {
auto inputType = input.getType().cast<BaseTensorType>();
llvm::SmallDenseSet<char> outTokenSet(outTokens.begin(), outTokens.end());
SmallVector<int64_t> sumDims;
llvm::SmallDenseMap<char, int64_t> inputDimToIdx;
int64_t idx = 0;
for (size_t i = 0; i < inputTokens.size(); ++i) {
char d = inputTokens[i];
if (!outTokenSet.contains(d)) {
sumDims.emplace_back(i);
} else {
inputDimToIdx[d] = idx++;
}
}
if (sumDims.size() > 0) {
SmallVector<Value> sumDimsTensor;
for (auto d : sumDims) {
sumDimsTensor.emplace_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(d)));
}
auto sumDimsListValue = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())),
sumDimsTensor);
auto falseValue = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(false));
auto noneValue = rewriter.create<Torch::ConstantNoneOp>(loc);
input = rewriter.create<Torch::AtenSumDimIntListOp>(
loc,
inputType.getWithSizesAndDtype(std::nullopt,
inputType.getOptionalDtype()),
input, sumDimsListValue, falseValue, noneValue);
}
SmallVector<Value> permuteDimsTensor;
for (auto d : outTokens) {
permuteDimsTensor.emplace_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputDimToIdx[d])));
}
auto permuteDimsListValue = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(input.getContext())),
permuteDimsTensor);
auto out = rewriter.create<Torch::AtenPermuteOp>(loc, outType, input,
permuteDimsListValue);
return out;
}
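A standalone sketch of this last step: dims missing from the output tokens are summed away, then the survivors are permuted into the requested order.

```python
import torch

x = torch.rand(2, 3, 4)        # input tokens: a b c
# requested output tokens: c a  ->  b is a leftover reduce dim
reduced = x.sum(dim=1)         # shape (2, 4), remaining dims [a, c]
final = reduced.permute(1, 0)  # reorder to [c, a]
assert torch.allclose(final, torch.einsum('abc->ca', x))
```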
namespace {
/// We decompose aten.amax into a set of aten.max.dim op(s) depending on the
/// number of dimensions across which the max needs to be computed.
@@ -628,6 +980,78 @@ public:
};
} // namespace
namespace {
// Decompose AtenEinsumOp into AtenMatmulOp, together with the reduce and
// permute operations needed along the way. Currently this pattern does not
// support the Hadamard product. The basic idea is:
// Step 1: split the equation string into input/result tokens and classify
// each dim as a batching, contracting, other, or reduce dim.
// Step 2: permute and reshape the input tensors into a form suitable for
// matmul operations.
// Step 3: use AtenMatmulOp to compute the pairwise result.
// Step 4: repeat steps 2 & 3 until all inputs are folded into the final
// result.
// Step 5: perform the remaining permute and reduce operations.
// Note: only static shapes are supported. (A Python sketch of this loop
// follows the pattern below.)
class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenEinsumOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
std::string equation;
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
}
SmallVector<char> resultTokens;
SmallVector<SmallVector<char>> inputTokens;
if (!parseEquation(equation, inputTokens, resultTokens)) {
return rewriter.notifyMatchFailure(
op, "Unexpected character in equations encountered");
}
SmallVector<Value> inputTensors;
if (!getListConstructElements(op.getTensors(), inputTensors)) {
return rewriter.notifyMatchFailure(
op, "input should comes from a PrimListConstructOp");
}
auto allTensorHasSizes = [](Value tensor) {
auto type = tensor.getType().dyn_cast<BaseTensorType>();
if (!type || !type.hasSizes())
return false;
return true;
};
if (!llvm::all_of(inputTensors, allTensorHasSizes)) {
return rewriter.notifyMatchFailure(op,
"all input tensors should have sizes");
}
SmallVector<char> lhsTokens = inputTokens[0];
Value lhs = inputTensors[0];
Value result;
for (size_t i = 1; i < inputTensors.size(); ++i) {
auto rhs = inputTensors[i];
auto rhsTokens = inputTokens[i];
SmallVector<char> outTokens;
if (failed(performMatmul(rewriter, loc, lhs, lhsTokens, rhs, rhsTokens,
result, outTokens, resultTokens))) {
return failure();
}
lhs = result;
lhsTokens = outTokens;
}
result = performLastReduceAndPermute(rewriter, loc, op.getType(), lhs,
lhsTokens, resultTokens);
rewriter.replaceOp(op, result);
return success();
}
};
} // namespace
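The rewrite loop above, end to end, as a hypothetical Python sketch. It assumes an explicit '->' in the equation and, unlike the pass (which contracts shared dims at every pairwise matmul), simply keeps every dim alive until the final reduce, trading memory for brevity:

```python
import torch

def decompose_einsum(equation, tensors):
    inputs, result = equation.split('->')
    tokens = inputs.split(',')
    lhs, lhs_tok = tensors[0], tokens[0]
    for rhs, rhs_tok in zip(tensors[1:], tokens[1:]):
        out_tok = ''.join(dict.fromkeys(lhs_tok + rhs_tok))  # dedup, keep order
        lhs = torch.einsum(f'{lhs_tok},{rhs_tok}->{out_tok}', lhs, rhs)
        lhs_tok = out_tok
    # Final reduce + permute, mirroring performLastReduceAndPermute.
    return torch.einsum(f'{lhs_tok}->{result}', lhs)

t = [torch.rand(3, 2, 4), torch.rand(5, 4, 6), torch.rand(3, 7, 6)]
ref = torch.einsum('bqe,ked,btd->bqtk', *t)
assert torch.allclose(decompose_einsum('bqe,ked,btd->bqtk', t), ref, atol=1e-4)
```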
// Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)).
// To avoid overflow we use the following decomposition rule:
@@ -5798,6 +6222,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSiluOp>(patterns);


@@ -385,6 +385,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenReshapeOp>();
target.addIllegalOp<Aten_SoftmaxBackwardDataOp>();
target.addIllegalOp<AtenTanhBackwardOp>();
target.addIllegalOp<AtenEinsumOp>();
target.addIllegalOp<AtenAddmmOp>();
target.addIllegalOp<AtenMeanOp>();
target.addIllegalOp<AtenMeanDimOp>();


@@ -554,6 +554,9 @@ STABLEHLO_PASS_SET = {
"EmptyLikeModule_int",
"ExpandAsIntModule_basic",
"ExpandModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticContractRhsModule_basic",
"Fill_TensorFloat64WithFloat32_basic", "Fill_TensorFloat64WithFloat32_basic",
"Fill_TensorFloat64WithFloat64_basic", "Fill_TensorFloat64WithFloat64_basic",
"Fill_TensorFloat64WithInt64_basic", "Fill_TensorFloat64WithInt64_basic",
@@ -1020,6 +1023,9 @@ TOSA_PASS_SET = {
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"EinsumStaticModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticContractRhsModule_basic",
"ElementwiseBitwiseAndModule_basic", "ElementwiseBitwiseAndModule_basic",
"ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt32Module_basic", "ElementwiseBitwiseNotInt32Module_basic",


@@ -3684,6 +3684,19 @@ def aten〡cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0)
        dtypes.append(tensor_dtype)
    return promote_dtypes(ranks, dtypes)
@check_dtype_function(
    [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32),
                            TensorOfShape(1, dtype=torch.int32)]),])
def aten〡einsum〡dtype(equation: str, tensors_rank_dtype: List[Tuple[int, int]], path: Optional[List[int]] = None) -> int:
    ranks: List[Optional[int]] = []
    dtypes: List[int] = []
    assert len(tensors_rank_dtype) != 0
    for tensor_rank_dtype in tensors_rank_dtype:
        tensor_rank, tensor_dtype = tensor_rank_dtype
        ranks.append(tensor_rank)
        dtypes.append(tensor_dtype)
    return promote_dtypes(ranks, dtypes)
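What the dtype function encodes, checked against eager PyTorch (the result dtype is the promotion of all operand dtypes; this mirrors the Invocation in the decorator above):

```python
import torch

a = torch.rand(1, dtype=torch.float32)
b = torch.ones(1, dtype=torch.int32)
out = torch.ops.aten.einsum('i,j->ij', [a, b])
assert out.dtype == torch.promote_types(torch.float32, torch.int32)  # float32
```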
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〡_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    return torch.int64


@@ -566,6 +566,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)")
emit("aten::one_hot : (Tensor, int) -> (Tensor)")
emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)") emit("aten::clone : (Tensor, int?) -> (Tensor)")
emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)") emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)")


@@ -1044,3 +1044,59 @@ class UnflattenIntNegativeOneSizeStaticModule(torch.nn.Module):
@register_test_case(module_factory=lambda: UnflattenIntNegativeOneSizeStaticModule())
def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 12, 3))
# ==============================================================================
class EinsumStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 2, 4], torch.float32, True),
        ([5, 4, 6], torch.float32, True),
        ([3, 7, 6], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2, tensor3):
        return torch.ops.aten.einsum('bqe,ked,btd->bqtk', [tensor1, tensor2, tensor3])


@register_test_case(module_factory=lambda: EinsumStaticModule())
def EinsumStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 2, 4), tu.rand(5, 4, 6), tu.rand(3, 7, 6))


class EinsumStaticFourDimensionModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5, 6], torch.float32, True),
        ([3, 7, 5, 6], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        return torch.ops.aten.einsum('blhd,bshd->blhs', [tensor1, tensor2])


@register_test_case(module_factory=lambda: EinsumStaticFourDimensionModule())
def EinsumStaticFourDimensionModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, 6), tu.rand(3, 7, 5, 6))


class EinsumStaticContractRhsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5], torch.float32, True),
        ([4, 5], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        return torch.ops.aten.einsum('abc,bc->a', [tensor1, tensor2])


@register_test_case(module_factory=lambda: EinsumStaticContractRhsModule())
def EinsumStaticContractRhsModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5), tu.rand(4, 5))