mirror of https://github.com/llvm/torch-mlir
support `aten._trilinear` and improve `einsum` decomposition (#3784)
# Tracking [Issue](https://github.com/nod-ai/SHARK-ModelDev/issues/848) [TorchToLinalg Op Support](https://github.com/nod-ai/SHARK-ModelDev/issues/347) # Description Aten_TrilinearOp is an implementation of a "trilinear einstein sum". Essentially, just an einsum across 3 tensors. There are a few inputs: ## Tensor Inputs - i1, i2, i3 - The three input tensors for the _trilinear op. ## Expands These inputs allow you to unsqueeze an input tensor at the specified dims as a pre-processing step to make the shapes compatible for the rest of the op: - expand1: List[int], expand2: List[int], expand3: List[int] ## sumdim - sumdim: List[int] - After applying element wise multiplication, the values in sumdim denote where to collapse a dimension by summing over it ## unroll_dim - unroll_dim: int - In the PyTorch implementation, this specifies a dimension where you could slice the input tensors, multiply and sum them, then concatenate the results in an output tensor. This complicates the implementation significantly, but doesn't change the result, so I opted against it. Along with that, a previously accepted path for solving this involved reusing the AtenEinsumOp, which also would also ignore this input. # Solution After trying a bunch of more complicated approaches for it, this op actually ended up being quite simple: [See _trilinear](https://dev-discuss.pytorch.org/t/defining-the-core-aten-opset/1464) `_trilinear = (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3)).sum(sumdim)` Wish I saw this earlier, but watcha gonna do: 🙃 ## Not Reusing AtenEinsumOp Frankly, I found multiple cases where valid inputs would have numerical mismatches for EinsumOp, even when running tests against EinsumOp directly. I think it has something to do with the singleton dimensions. Will need to look into this further, but once I realized the simplified approach, it appeared to be more reliable and much simpler. Either way (credit to @zjgarvey), there are improvements to the einsum op here. When I was originally trying to use the op, intermediate tensors were being flattened properly, but then its 0th dimension was being cast from a static dim to a dynamic dim due to integers not folding correctly in the MLIR. Figured it's worth keeping these improvements for future reusers of EinsumOp. # The zero'd out dim "bug" For some reason, if you specify a dimension in all `expands`, ```i.e. [expand1=[0], expand2=[0], expand3=[0]], [expand1=[1], expand2=[1], expand3=[1]] ``` The _trilinear op would specify `0` for that dimension in the output shape, unless it was also included in `sumdim`. This goes against the implementation of torch.einsum: ``` >>> a, b, c = [torch.rand(1, 3, 3, 3) for i in range(3)] # Simulate expand at dim=0 for all input tensors >>> torch.einsum('abcd,abcd,abcd->abcd', a, b, c).shape torch.Size([1, 3, 3, 3]) ``` And is just straight up incorrect mathematically. I considered "replacing" singleton dims with zeroed out dims, but that seemed like carrying over a bug. Instead, I included a test for the case, verified that the singleton dimensions were handled the way that torch.einsum handles it, instead of torch._trilinear, and xfailed it with a note as to why.pull/3820/head
parent
8f52f5a4ed
commit
9c1e3b8154
|
@ -14248,6 +14248,36 @@ def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten_TrilinearOp : Torch_Op<"aten._trilinear", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$i1,
|
||||
AnyTorchTensorType:$i2,
|
||||
AnyTorchTensorType:$i3,
|
||||
AnyTorchListOfTorchIntType:$expand1,
|
||||
AnyTorchListOfTorchIntType:$expand2,
|
||||
AnyTorchListOfTorchIntType:$expand3,
|
||||
AnyTorchListOfTorchIntType:$sumdim,
|
||||
Torch_IntType:$unroll_dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult Aten_TrilinearOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 8, 1);
|
||||
}
|
||||
void Aten_TrilinearOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 8, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -8864,6 +8864,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._trilinear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.list<int> {\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %str = torch.constant.str \"AssertionError: number of dimensions must match\"\n"
|
||||
" %str_0 = torch.constant.str \"expand dimension {} is out of bounds for input of shape {}\"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str_1 = torch.constant.str \"AssertionError: \"\n"
|
||||
" %str_2 = torch.constant.str \"unroll_dim must be in [0, {}]\"\n"
|
||||
" %false = torch.constant.bool false\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
|
||||
" %2 = torch.aten.add.int %0, %1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %3 = torch.aten.ge.int %arg7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
|
||||
" %23 = torch.aten.lt.int %arg7, %2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %23 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %4 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %23 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %24 = torch.aten.format(%str_2, %23) : !torch.str, !torch.int -> !torch.str\n"
|
||||
" %25 = torch.aten.add.str %str_1, %24 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %25, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %5 = call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %6 = call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %7 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %8 = torch.prim.ListConstruct %5, %6, %7 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
|
||||
" %9 = torch.prim.ListConstruct %arg3, %arg4, %arg5 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
|
||||
" torch.prim.Loop %int3, %true, init() {\n"
|
||||
" ^bb0(%arg8: !torch.int):\n"
|
||||
" %23 = torch.aten.__getitem__.t %9, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||
" %24 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||
" %25 = torch.aten.len.t %24 : !torch.list<int> -> !torch.int\n"
|
||||
" %26 = torch.aten.len.t %23 : !torch.list<int> -> !torch.int\n"
|
||||
" torch.prim.Loop %26, %true, init() {\n"
|
||||
" ^bb0(%arg9: !torch.int):\n"
|
||||
" %27 = torch.aten.__getitem__.t %23, %arg9 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %28 = torch.aten.le.int %27, %25 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %28 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" %30 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||
" %31 = torch.aten.format(%str_0, %27, %30) : !torch.str, !torch.int, !torch.list<int> -> !torch.str\n"
|
||||
" %32 = torch.aten.add.str %str_1, %31 : !torch.str, !torch.str -> !torch.str\n"
|
||||
" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %29 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
|
||||
" torch.aten.insert.t %29, %27, %int1 : !torch.list<int>, !torch.int, !torch.int\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %10 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int\n"
|
||||
" %11 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
|
||||
" %12 = torch.aten.eq.int %10, %11 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
|
||||
" %23 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
|
||||
" %24 = torch.aten.len.t %7 : !torch.list<int> -> !torch.int\n"
|
||||
" %25 = torch.aten.eq.int %23, %24 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %25 : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %false : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %13 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list<bool>\n"
|
||||
" %16 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
|
||||
" %17 = torch.operator \"aten.mul.left_t\"(%15, %16) : (!torch.list<bool>, !torch.int) -> !torch.list<bool> \n"
|
||||
" %18 = torch.aten.len.t %arg6 : !torch.list<int> -> !torch.int\n"
|
||||
" torch.prim.Loop %18, %true, init() {\n"
|
||||
" ^bb0(%arg8: !torch.int):\n"
|
||||
" %23 = torch.aten.__getitem__.t %arg6, %arg8 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %24 = torch.aten._set_item.t %17, %23, %true : !torch.list<bool>, !torch.int, !torch.bool -> !torch.list<bool>\n"
|
||||
" torch.prim.Loop.condition %true, iter()\n"
|
||||
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||
" %19 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
|
||||
" %20 = torch.aten.sub.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %21 = torch.aten.__range_length %20, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
|
||||
" %22 = torch.prim.Loop %21, %true, init(%14) {\n"
|
||||
" ^bb0(%arg8: !torch.int, %arg9: !torch.list<int>):\n"
|
||||
" %23 = torch.aten.__derive_index %arg8, %20, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
|
||||
" %24 = torch.aten.__getitem__.t %17, %23 : !torch.list<bool>, !torch.int -> !torch.bool\n"
|
||||
" %25 = torch.prim.If %24 -> (!torch.list<int>) {\n"
|
||||
" %26 = func.call @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg9, %23, %false) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %26 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" torch.prim.If.yield %arg9 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" torch.prim.Loop.condition %true, iter(%25 : !torch.list<int>)\n"
|
||||
" } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %22 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>, %arg7: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
|
@ -15294,6 +15400,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._trilinear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %5 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list<tuple<int, int>>, %arg1: !torch.int) -> !torch.int {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %none = torch.constant.none\n"
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
@ -399,9 +400,9 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
|||
auto inputType = cast<ValueTensorType>(input.getType());
|
||||
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
||||
reduceDimsLength;
|
||||
SmallVector<Value> inputShapeTensor;
|
||||
SmallVector<OpFoldResult> inputShapeTensor;
|
||||
for (auto i = 0; i < inputRank; ++i) {
|
||||
inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>(
|
||||
inputShapeTensor.emplace_back(rewriter.createOrFold<AtenSizeIntOp>(
|
||||
loc, input,
|
||||
rewriter.create<Torch::ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(i))));
|
||||
|
@ -412,13 +413,23 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
|||
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
auto dimOffset = 0;
|
||||
|
||||
auto appendDims = [&](int64_t dimLength) {
|
||||
Value prod = constOne;
|
||||
for (auto i = 0; i < dimLength; ++i) {
|
||||
prod = rewriter.create<AtenMulIntOp>(loc, prod,
|
||||
inputShapeTensor[i + dimOffset]);
|
||||
auto materializeIntFold = [&](OpFoldResult thing) {
|
||||
if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
|
||||
Value result = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, cast<mlir::IntegerAttr>(attr));
|
||||
return result;
|
||||
}
|
||||
outShapeTensor.emplace_back(prod);
|
||||
return cast<mlir::Value>(thing);
|
||||
};
|
||||
|
||||
auto appendDims = [&](int64_t dimLength) {
|
||||
OpFoldResult prod = getAsOpFoldResult(constOne);
|
||||
for (auto i = 0; i < dimLength; ++i) {
|
||||
prod = rewriter.createOrFold<AtenMulIntOp>(
|
||||
loc, materializeIntFold(prod),
|
||||
materializeIntFold(inputShapeTensor[i + dimOffset]));
|
||||
}
|
||||
outShapeTensor.emplace_back(materializeIntFold(prod));
|
||||
dimOffset += dimLength;
|
||||
};
|
||||
|
||||
|
@ -570,21 +581,32 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
|||
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
|
||||
: rhsType.getOptionalDtype();
|
||||
|
||||
auto materializeIntFold = [&](OpFoldResult thing) {
|
||||
if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
|
||||
Value result = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, cast<mlir::IntegerAttr>(attr));
|
||||
return result;
|
||||
}
|
||||
return cast<mlir::Value>(thing);
|
||||
};
|
||||
|
||||
llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
|
||||
for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
|
||||
char d = lhsTokens[idx];
|
||||
lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
|
||||
OpFoldResult lhsFold = rewriter.createOrFold<AtenSizeIntOp>(
|
||||
loc, lhs,
|
||||
rewriter.create<Torch::ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(idx)));
|
||||
lhsDimShapeMap[d] = materializeIntFold(lhsFold);
|
||||
}
|
||||
llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
|
||||
for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
|
||||
char d = rhsTokens[idx];
|
||||
rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>(
|
||||
OpFoldResult rhsFold = rewriter.createOrFold<AtenSizeIntOp>(
|
||||
loc, rhs,
|
||||
rewriter.create<Torch::ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(idx)));
|
||||
rhsDimShapeMap[d] = materializeIntFold(rhsFold);
|
||||
}
|
||||
|
||||
// parse batch, contracting, other, reduce dims of lhs and rhs
|
||||
|
@ -604,8 +626,9 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
|||
bool lhsContains = lhsDimShapeMap.count(d) > 0;
|
||||
bool rhsContains = rhsDimShapeMap.count(d) > 0;
|
||||
if (lhsContains && rhsContains) {
|
||||
outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>(
|
||||
OpFoldResult out = rewriter.createOrFold<Torch::PrimMaxIntOp>(
|
||||
loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
|
||||
outDimShapeMap[d] = materializeIntFold(out);
|
||||
} else if (lhsContains) {
|
||||
outDimShapeMap[d] = lhsDimShapeMap[d];
|
||||
} else if (rhsContains) {
|
||||
|
@ -1973,6 +1996,125 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Trilinear einstein sum, decomposed to:
|
||||
// (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3))
|
||||
// .sum(sumdim)
|
||||
// The unrollDim operand does not impact the output of the operation, so
|
||||
// it is ignored.
|
||||
|
||||
class DecomposeAten_TrilinearOp : public OpRewritePattern<Aten_TrilinearOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_TrilinearOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
|
||||
Value input1 = op.getI1();
|
||||
Value input2 = op.getI2();
|
||||
Value input3 = op.getI3();
|
||||
|
||||
// Expansions
|
||||
SmallVector<int64_t> expand1;
|
||||
SmallVector<int64_t> expand2;
|
||||
SmallVector<int64_t> expand3;
|
||||
if (!matchPattern(op.getExpand1(), m_TorchListOfConstantInts(expand1))) {
|
||||
return rewriter.notifyMatchFailure(op, "expand1 should be constant");
|
||||
}
|
||||
if (!matchPattern(op.getExpand2(), m_TorchListOfConstantInts(expand2))) {
|
||||
return rewriter.notifyMatchFailure(op, "expand2 should be constant");
|
||||
}
|
||||
if (!matchPattern(op.getExpand3(), m_TorchListOfConstantInts(expand3))) {
|
||||
return rewriter.notifyMatchFailure(op, "expand3 should be constant");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> sumDim;
|
||||
if (!matchPattern(op.getSumdim(), m_TorchListOfConstantInts(sumDim))) {
|
||||
return rewriter.notifyMatchFailure(op, "sumDim should be constant");
|
||||
}
|
||||
|
||||
// Check if there are any dimensions that intersect between expand1,
|
||||
// expand2, and expand3.
|
||||
int64_t totalDims =
|
||||
cast<BaseTensorType>(input1.getType()).getSizes().size() +
|
||||
expand1.size();
|
||||
if (sharedExpandDims(totalDims, expand1, expand2, expand3, sumDim)) {
|
||||
// pytorch issue filed: https://github.com/pytorch/pytorch/issues/138353
|
||||
// TODO: Remove warning when issue gets resolved.
|
||||
op->emitWarning("aten::_trilinear implementation in this case is "
|
||||
"non-functional (returns an empty dimension). We will "
|
||||
"intentionally deviate from this behavior.");
|
||||
}
|
||||
|
||||
// Apply unsqueeze to respective input tensors at the specified dimensions
|
||||
SmallVector<int64_t> sortedExpand1 = expand1;
|
||||
std::sort(sortedExpand1.begin(), sortedExpand1.end());
|
||||
for (auto expand : sortedExpand1) {
|
||||
Value expandDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(expand));
|
||||
input1 = *unsqueezeTensor(rewriter, op, input1, expandDim);
|
||||
}
|
||||
SmallVector<int64_t> sortedExpand2 = expand2;
|
||||
std::sort(sortedExpand2.begin(), sortedExpand2.end());
|
||||
for (auto expand : sortedExpand2) {
|
||||
Value expandDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(expand));
|
||||
input2 = *unsqueezeTensor(rewriter, op, input2, expandDim);
|
||||
}
|
||||
SmallVector<int64_t> sortedExpand3 = expand3;
|
||||
std::sort(sortedExpand3.begin(), sortedExpand3.end());
|
||||
for (auto expand : sortedExpand3) {
|
||||
Value expandDim = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(expand));
|
||||
input3 = *unsqueezeTensor(rewriter, op, input3, expandDim);
|
||||
}
|
||||
|
||||
// Apply multiplication operation.
|
||||
auto mul1 =
|
||||
rewriter.create<AtenMulTensorOp>(loc, op.getType(), input1, input2);
|
||||
auto mul2 =
|
||||
rewriter.create<AtenMulTensorOp>(loc, op.getType(), mul1, input3);
|
||||
|
||||
// Apply sum operation.
|
||||
// Parse sumDim in descending order to avoid any issues with the
|
||||
// dimensions being removed.
|
||||
Value result = mul2;
|
||||
SmallVector<int64_t> sortedSumDims = sumDim;
|
||||
std::sort(sortedSumDims.rbegin(), sortedSumDims.rend());
|
||||
for (int64_t dim : sortedSumDims) {
|
||||
Value dimValue = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(dim));
|
||||
result =
|
||||
createSumAlongDimension(rewriter, loc, op, result, dimValue, false);
|
||||
}
|
||||
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// Determine if there are any dimensions that intersect between expand1,
|
||||
// expand2, and expand3.
|
||||
bool sharedExpandDims(const int64_t &totalDims,
|
||||
const SmallVector<int64_t> &expand1,
|
||||
const SmallVector<int64_t> &expand2,
|
||||
const SmallVector<int64_t> &expand3,
|
||||
const SmallVector<int64_t> &sumDim) const {
|
||||
for (int64_t i = 0; i < totalDims; ++i) {
|
||||
if (!contains(sumDim, i) && contains(expand1, i) &&
|
||||
contains(expand2, i) && contains(expand3, i)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool contains(const SmallVector<int64_t> &vec, int64_t value) const {
|
||||
return std::find(vec.begin(), vec.end(), value) != vec.end();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Calculate the trace of the input tensor as the sum over its diagonal
|
||||
// elements. This computation is performed as:
|
||||
|
@ -10078,6 +10220,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_TrilinearOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);
|
||||
|
|
|
@ -400,6 +400,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenAtleast1dOp>();
|
||||
target.addIllegalOp<AtenAtleast2dOp>();
|
||||
target.addIllegalOp<AtenEinsumOp>();
|
||||
target.addIllegalOp<Aten_TrilinearOp>();
|
||||
target.addIllegalOp<AtenTraceOp>();
|
||||
target.addIllegalOp<AtenAddmmOp>();
|
||||
target.addIllegalOp<AtenMeanOp>();
|
||||
|
|
|
@ -29,6 +29,10 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
|||
"DeformConv2D_basic",
|
||||
"ReduceAnyDimFloatModule_basic",
|
||||
"UnfoldModule_basic",
|
||||
# _trilinear is an implementation of einsum, but sets dimensions to zero
|
||||
# if a dimension is specified in all expand lists, and not in sumdim list.
|
||||
# This is a bug in the implementation of _trilinear in PyTorch.
|
||||
"Aten_TrilinearModuleZerodDimBug_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.5.0.dev"):
|
||||
|
@ -394,6 +398,8 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"AtenIntBoolOpModule_basic",
|
||||
"AtenIntMM_basic",
|
||||
"AtenItemFpOpModule_basic",
|
||||
"Aten_TrilinearModuleVaryingRanks_basic",
|
||||
"Aten_TrilinearModuleZerodDimBug_basic",
|
||||
"QuantizedReluInt32_basic",
|
||||
"QuantizedReluInt8_basic",
|
||||
"QuantizedReluUint8_basic",
|
||||
|
@ -532,6 +538,9 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
|
|||
"_SoftmaxModule_basic",
|
||||
"UpSampleNearest2dDynamicFactor_basic",
|
||||
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
|
||||
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
|
||||
"Aten_TrilinearModuleSumAllDims_basic",
|
||||
"Aten_TrilinearModuleSumdims_basic",
|
||||
# torch export: RuntimeError: cannot mutate tensors with frozen storage
|
||||
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||
|
@ -645,6 +654,8 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"AtenTopKModule_basic",
|
||||
"AtenTopKSmallestModule_basic",
|
||||
"Aten_EmbeddingBagExample_basic",
|
||||
"Aten_TrilinearModuleVaryingRanks_basic",
|
||||
"Aten_TrilinearModuleZerodDimBug_basic",
|
||||
"AvgPool2dDivisorOverrideModule_basic",
|
||||
"BernoulliTensorModule_basic",
|
||||
"BincountMinlengthModule_basic",
|
||||
|
@ -928,11 +939,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"AtenItemIntOpModule_basic",
|
||||
"CrossEntropyLossModule_basic",
|
||||
"CrossEntropyLossNoReductionModule_basic",
|
||||
"EinsumStaticContractRhsModule_basic",
|
||||
"EinsumStaticFourDimensionModule_basic",
|
||||
"EinsumStaticModule_basic",
|
||||
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
|
||||
"EinsumStaticWithEllipsisSlicingModule_basic",
|
||||
"ElementwiseExpm1IntModule_basic",
|
||||
"ElementwiseExpm1Module_basic",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
|
@ -984,6 +990,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
|
|||
# materialization callback produced value of incorrect type failed
|
||||
"ReduceMaxAlongDimUnsignedInt_basic",
|
||||
"ReduceMinAlongDimUnsignedInt_basic",
|
||||
"Aten_TrilinearModuleSumdims_basic",
|
||||
"Aten_TrilinearModuleSumAllDims_basic",
|
||||
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
|
||||
# torch export: RuntimeError: cannot mutate tensors with frozen storage
|
||||
"ElementwiseRreluWithNoiseTrainModule_basic",
|
||||
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
|
||||
|
@ -3275,6 +3284,12 @@ ONNX_XFAIL_SET = {
|
|||
"Unfold_Module_Rank_Zero_Size_Zero_basic",
|
||||
"Unfold_Module_Dynamic_basic",
|
||||
"ViewDtypeStaticModule_basic",
|
||||
"Aten_TrilinearModule_basic",
|
||||
"Aten_TrilinearModuleSumdims_basic",
|
||||
"Aten_TrilinearModuleSumAllDims_basic",
|
||||
"Aten_TrilinearModuleVaryingRanks_basic",
|
||||
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
|
||||
"Aten_TrilinearModuleZerodDimBug_basic",
|
||||
}
|
||||
|
||||
if torch_version_for_comparison() < version.parse("2.3.0.dev"):
|
||||
|
@ -4055,6 +4070,12 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"AtenSubFloatModule_basic",
|
||||
"AtenTopKModule_basic",
|
||||
"AtenTopKSmallestModule_basic",
|
||||
"Aten_TrilinearModule_basic",
|
||||
"Aten_TrilinearModuleSumdims_basic",
|
||||
"Aten_TrilinearModuleSumAllDims_basic",
|
||||
"Aten_TrilinearModuleVaryingRanks_basic",
|
||||
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
|
||||
"Aten_TrilinearModuleZerodDimBug_basic",
|
||||
"AtenTrilModule_basic",
|
||||
"AtenTrilWithNegDiagonalModule_basic",
|
||||
"AtenTrilWithPosDiagonalModule_basic",
|
||||
|
|
|
@ -1295,6 +1295,44 @@ def aten〇unflatten〇int〡shape(self: List[int], dim: int, sizes: List[int])
|
|||
def aten〇linear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
|
||||
return upstream_shape_functions.linear(input, weight, bias)
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [], [], [], [], 0), # Basic case
|
||||
Invocation(TensorOfShape(4, 5, 6), TensorOfShape(4, 5, 6), TensorOfShape(4, 5, 6), [1], [0], [0], [], 2), # Expansions w/ Non-Zero unroll_dim
|
||||
Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [1, 2], [1, 2], [1, 2], [1, 2], 0), # Multiple expansions
|
||||
Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [1, 2], [2, 1], [1, 2], [1, 2], 0), # Unordered expansion
|
||||
ErrorInvocation(TensorOfShape(4, 5, 1), TensorOfShape(4, 5, 3), TensorOfShape(1, 5, 3), [], [], [0], [2], 0), # Num dimensions don't match
|
||||
])
|
||||
def aten〇_trilinear〡shape(i1: List[int], i2: List[int], i3: List[int], expand1: List[int], expand2: List[int], expand3: List[int], sumdim: List[int], unroll_dim: int = 1) -> List[int]:
|
||||
total_dims = len(i1) + len(expand1)
|
||||
|
||||
assert unroll_dim >= 0 and unroll_dim < total_dims, f"unroll_dim must be in [0, {total_dims - 1}]"
|
||||
|
||||
i1_copy = upstream_shape_functions._copy(i1)
|
||||
i2_copy = upstream_shape_functions._copy(i2)
|
||||
i3_copy = upstream_shape_functions._copy(i3)
|
||||
|
||||
# Expand dimensions based on args
|
||||
inputs = [i1_copy, i2_copy, i3_copy]
|
||||
expands = [expand1, expand2, expand3]
|
||||
for index, expand in enumerate(expands):
|
||||
size = len(inputs[index])
|
||||
for dim in expand:
|
||||
assert dim <= size, f"expand dimension {dim} is out of bounds for input of shape {inputs[index]}"
|
||||
inputs[index].insert(dim, 1)
|
||||
|
||||
assert len(i1_copy) == len(i2_copy) == len(i3_copy), "number of dimensions must match"
|
||||
|
||||
output_shape = upstream_shape_functions.broadcast_three(i1_copy, i2_copy, i3_copy)
|
||||
sumdim_bools = [False] * len(output_shape)
|
||||
for dim in sumdim:
|
||||
sumdim_bools[dim] = True
|
||||
|
||||
for i in range(len(output_shape) - 1, -1, -1):
|
||||
if sumdim_bools[i]:
|
||||
output_shape = upstream_shape_functions._reduce_along_dim(output_shape, i, False)
|
||||
|
||||
return output_shape
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Same shape
|
||||
Invocation(TensorOfShape(3, 2, 16, 8), TensorOfShape(3, 2, 8, 8), TensorOfShape(3, 2, 8, 4)), # Different shape
|
||||
|
@ -5388,6 +5426,21 @@ def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype:
|
|||
promoted_dtype = promote_dtypes(ranks, dtypes)
|
||||
return promoted_dtype
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(3, None, None, None, expand1 = [], expand2 = [], expand3 = [], sumdim = [], unroll_dim = 0),
|
||||
)
|
||||
def aten〇_trilinear〡dtype(i1_rank_dtype: Tuple[int, int], i2_rank_dtype: Tuple[int, int], i3_rank_dtype: Tuple[int, int], expand1: List[int], expand2: List[int], expand3: List[int], sumdim: List[int], unroll_dim: int = 1) -> int:
|
||||
i1_rank, i1_dtype = i1_rank_dtype
|
||||
i2_rank, i2_dtype = i2_rank_dtype
|
||||
i3_rank, i3_dtype = i3_rank_dtype
|
||||
|
||||
ranks: List[Optional[int]] = [i1_rank, i2_rank, i3_rank]
|
||||
dtypes = [i1_dtype, i2_dtype, i3_dtype]
|
||||
return promote_dtypes(
|
||||
ranks,
|
||||
dtypes,
|
||||
)
|
||||
|
||||
@check_dtype_function(
|
||||
[Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
|
||||
Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]),
|
||||
|
|
|
@ -1022,6 +1022,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)"
|
||||
)
|
||||
emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
|
||||
emit(
|
||||
"aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)"
|
||||
)
|
||||
|
||||
# Dict ops.
|
||||
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
|
||||
|
|
|
@ -1674,6 +1674,9 @@ def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils):
|
|||
module.forward(tu.rand(6, 5, 1, 7, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Unfold_Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -1772,3 +1775,173 @@ class Unfold_Module_Dynamic(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: Unfold_Module_Dynamic())
|
||||
def Unfold_Module_Dynamic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(6, 4, 4, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class Aten_TrilinearModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3, 3, 3], torch.float32, True),
|
||||
([3, 3, 3], torch.float32, True),
|
||||
([3, 3, 3], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, i1, i2, i3):
|
||||
return torch.ops.aten._trilinear(
|
||||
i1, i2, i3, expand1=[], expand2=[], expand3=[], sumdim=[], unroll_dim=0
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Aten_TrilinearModule())
|
||||
def Aten_TrilinearModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 3, 3), tu.rand(3, 3, 3), tu.rand(3, 3, 3))
|
||||
|
||||
|
||||
class Aten_TrilinearModuleSumdims(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 6], torch.float32, True),
|
||||
([2, 3, 6], torch.float32, True),
|
||||
([2, 3, 6], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, i1, i2, i3):
|
||||
return torch.ops.aten._trilinear(
|
||||
i1, i2, i3, expand1=[1], expand2=[], expand3=[], sumdim=[0, 2], unroll_dim=0
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Aten_TrilinearModuleSumdims())
|
||||
def Aten_TrilinearModuleSumdims_basic(module, tu: TestUtils):
|
||||
return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6))
|
||||
|
||||
|
||||
class Aten_TrilinearModuleSumAllDims(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 6], torch.float32, True),
|
||||
([2, 3, 6], torch.float32, True),
|
||||
([2, 3, 6], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, i1, i2, i3):
|
||||
return torch.ops.aten._trilinear(
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
expand1=[1],
|
||||
expand2=[],
|
||||
expand3=[],
|
||||
sumdim=[0, 1, 2],
|
||||
unroll_dim=0,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Aten_TrilinearModuleSumAllDims())
|
||||
def Aten_TrilinearModuleSumAllDims_basic(module, tu: TestUtils):
|
||||
return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6))
|
||||
|
||||
|
||||
class Aten_TrilinearModuleVaryingRanks(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 6], torch.float32, True),
|
||||
([2, 3, 6], torch.float32, True),
|
||||
([6], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, i1, i2, i3):
|
||||
return torch.ops.aten._trilinear(
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
expand1=[1],
|
||||
expand2=[],
|
||||
expand3=[0, 1],
|
||||
sumdim=[0],
|
||||
unroll_dim=0,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Aten_TrilinearModuleVaryingRanks())
|
||||
def Aten_TrilinearModuleVaryingRanks_basic(module, tu: TestUtils):
|
||||
return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(6))
|
||||
|
||||
|
||||
class Aten_TrilinearModuleVaryingRanksUnorderedExpands(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 6], torch.float32, True),
|
||||
([2, 3, 6], torch.float32, True),
|
||||
([6], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, i1, i2, i3):
|
||||
return torch.ops.aten._trilinear(
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
expand1=[1],
|
||||
expand2=[],
|
||||
expand3=[1, 0],
|
||||
sumdim=[2, 0],
|
||||
unroll_dim=0,
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(
|
||||
module_factory=lambda: Aten_TrilinearModuleVaryingRanksUnorderedExpands()
|
||||
)
|
||||
def Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic(module, tu: TestUtils):
|
||||
return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(6))
|
||||
|
||||
|
||||
class Aten_TrilinearModuleZerodDimBug(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([2, 3, 6], torch.float32, True),
|
||||
([2, 3, 6], torch.float32, True),
|
||||
([2, 3, 6], torch.float32, True),
|
||||
]
|
||||
)
|
||||
def forward(self, i1, i2, i3):
|
||||
return torch.ops.aten._trilinear(
|
||||
i1, i2, i3, expand1=[0], expand2=[0], expand3=[0], sumdim=[2], unroll_dim=0
|
||||
)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Aten_TrilinearModuleZerodDimBug())
|
||||
def Aten_TrilinearModuleZerodDimBug_basic(module, tu: TestUtils):
|
||||
return module.forward(tu.rand(2, 3, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6))
|
||||
|
|
Loading…
Reference in New Issue