support `aten._trilinear` and improve `einsum` decomposition (#3784)

# Tracking
[Issue](https://github.com/nod-ai/SHARK-ModelDev/issues/848)
[TorchToLinalg Op
Support](https://github.com/nod-ai/SHARK-ModelDev/issues/347)

# Description

Aten_TrilinearOp is an implementation of a "trilinear einstein sum".
Essentially, just an einsum across 3 tensors.

There are a few inputs:
## Tensor Inputs
- i1, i2, i3 - The three input tensors for the _trilinear op.
## Expands 
These inputs allow you to unsqueeze an input tensor at the specified
dims as a pre-processing step to make the shapes compatible for the rest
of the op:
- expand1: List[int], expand2: List[int], expand3: List[int]

## sumdim
- sumdim: List[int] - After applying element wise multiplication, the
values in sumdim denote where to collapse a dimension by summing over it

## unroll_dim
- unroll_dim: int - In the PyTorch implementation, this specifies a
dimension where you could slice the input tensors, multiply and sum
them, then concatenate the results in an output tensor. This complicates
the implementation significantly, but doesn't change the result, so I
opted against it. Along with that, a previously accepted path for
solving this involved reusing the AtenEinsumOp, which also would also
ignore this input.


# Solution

After trying a bunch of more complicated approaches for it, this op
actually ended up being quite simple: [See
_trilinear](https://dev-discuss.pytorch.org/t/defining-the-core-aten-opset/1464)

`_trilinear = (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) *
i3.unsqueeze(expand3)).sum(sumdim)`

Wish I saw this earlier, but watcha gonna do: 🙃

## Not Reusing AtenEinsumOp
Frankly, I found multiple cases where valid inputs would have numerical
mismatches for EinsumOp, even when running tests against EinsumOp
directly. I think it has something to do with the singleton dimensions.
Will need to look into this further, but once I realized the simplified
approach, it appeared to be more reliable and much simpler.

Either way (credit to @zjgarvey), there are improvements to the einsum
op here. When I was originally trying to use the op, intermediate
tensors were being flattened properly, but then its 0th dimension was
being cast from a static dim to a dynamic dim due to integers not
folding correctly in the MLIR. Figured it's worth keeping these
improvements for future reusers of EinsumOp.

# The zero'd out dim "bug"

For some reason, if you specify a dimension in all `expands`,

```i.e. 
[expand1=[0], expand2=[0], expand3=[0]],
[expand1=[1], expand2=[1], expand3=[1]]
```

The _trilinear op would specify `0` for that dimension in the output
shape, unless it was also included in `sumdim`. This goes against the
implementation of torch.einsum:

```
>>> a, b, c = [torch.rand(1, 3, 3, 3) for i in range(3)] # Simulate expand at dim=0 for all input tensors
>>> torch.einsum('abcd,abcd,abcd->abcd', a, b, c).shape
torch.Size([1, 3, 3, 3])
```

And is just straight up incorrect mathematically. I considered
"replacing" singleton dims with zeroed out dims, but that seemed like
carrying over a bug. Instead, I included a test for the case, verified
that the singleton dimensions were handled the way that torch.einsum
handles it, instead of torch._trilinear, and xfailed it with a note as
to why.
pull/3820/head
Stephen Baione 2024-10-31 14:30:40 -05:00 committed by GitHub
parent 8f52f5a4ed
commit 9c1e3b8154
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 555 additions and 16 deletions

View File

@ -14248,6 +14248,36 @@ def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [
}]; }];
} }
def Torch_Aten_TrilinearOp : Torch_Op<"aten._trilinear", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$i1,
AnyTorchTensorType:$i2,
AnyTorchTensorType:$i3,
AnyTorchListOfTorchIntType:$expand1,
AnyTorchListOfTorchIntType:$expand2,
AnyTorchListOfTorchIntType:$expand3,
AnyTorchListOfTorchIntType:$sumdim,
Torch_IntType:$unroll_dim
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_TrilinearOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void Aten_TrilinearOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
}
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [ def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -8864,6 +8864,112 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten._trilinear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.list<int> {\n"
" %int3 = torch.constant.int 3\n"
" %int-1 = torch.constant.int -1\n"
" %str = torch.constant.str \"AssertionError: number of dimensions must match\"\n"
" %str_0 = torch.constant.str \"expand dimension {} is out of bounds for input of shape {}\"\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str_1 = torch.constant.str \"AssertionError: \"\n"
" %str_2 = torch.constant.str \"unroll_dim must be in [0, {}]\"\n"
" %false = torch.constant.bool false\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.len.t %arg3 : !torch.list<int> -> !torch.int\n"
" %2 = torch.aten.add.int %0, %1 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.aten.ge.int %arg7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %4 = torch.prim.If %3 -> (!torch.bool) {\n"
" %23 = torch.aten.lt.int %arg7, %2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %23 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %23 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %24 = torch.aten.format(%str_2, %23) : !torch.str, !torch.int -> !torch.str\n"
" %25 = torch.aten.add.str %str_1, %24 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %25, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = call @__torch__.torch.jit._shape_functions._copy(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %6 = call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
" %7 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>\n"
" %8 = torch.prim.ListConstruct %5, %6, %7 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
" %9 = torch.prim.ListConstruct %arg3, %arg4, %arg5 : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<list<int>>\n"
" torch.prim.Loop %int3, %true, init() {\n"
" ^bb0(%arg8: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %9, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %24 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %25 = torch.aten.len.t %24 : !torch.list<int> -> !torch.int\n"
" %26 = torch.aten.len.t %23 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %26, %true, init() {\n"
" ^bb0(%arg9: !torch.int):\n"
" %27 = torch.aten.__getitem__.t %23, %arg9 : !torch.list<int>, !torch.int -> !torch.int\n"
" %28 = torch.aten.le.int %27, %25 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %28 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" %30 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" %31 = torch.aten.format(%str_0, %27, %30) : !torch.str, !torch.int, !torch.list<int> -> !torch.str\n"
" %32 = torch.aten.add.str %str_1, %31 : !torch.str, !torch.str -> !torch.str\n"
" torch.prim.RaiseException %32, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %29 = torch.aten.__getitem__.t %8, %arg8 : !torch.list<list<int>>, !torch.int -> !torch.list<int>\n"
" torch.aten.insert.t %29, %27, %int1 : !torch.list<int>, !torch.int, !torch.int\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %10 = torch.aten.len.t %5 : !torch.list<int> -> !torch.int\n"
" %11 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
" %12 = torch.aten.eq.int %10, %11 : !torch.int, !torch.int -> !torch.bool\n"
" %13 = torch.prim.If %12 -> (!torch.bool) {\n"
" %23 = torch.aten.len.t %6 : !torch.list<int> -> !torch.int\n"
" %24 = torch.aten.len.t %7 : !torch.list<int> -> !torch.int\n"
" %25 = torch.aten.eq.int %23, %24 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %25 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" torch.prim.If %13 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list<int>, !torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list<bool>\n"
" %16 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
" %17 = torch.operator \"aten.mul.left_t\"(%15, %16) : (!torch.list<bool>, !torch.int) -> !torch.list<bool> \n"
" %18 = torch.aten.len.t %arg6 : !torch.list<int> -> !torch.int\n"
" torch.prim.Loop %18, %true, init() {\n"
" ^bb0(%arg8: !torch.int):\n"
" %23 = torch.aten.__getitem__.t %arg6, %arg8 : !torch.list<int>, !torch.int -> !torch.int\n"
" %24 = torch.aten._set_item.t %17, %23, %true : !torch.list<bool>, !torch.int, !torch.bool -> !torch.list<bool>\n"
" torch.prim.Loop.condition %true, iter()\n"
" } : (!torch.int, !torch.bool) -> ()\n"
" %19 = torch.aten.len.t %14 : !torch.list<int> -> !torch.int\n"
" %20 = torch.aten.sub.int %19, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %21 = torch.aten.__range_length %20, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %22 = torch.prim.Loop %21, %true, init(%14) {\n"
" ^bb0(%arg8: !torch.int, %arg9: !torch.list<int>):\n"
" %23 = torch.aten.__derive_index %arg8, %20, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n"
" %24 = torch.aten.__getitem__.t %17, %23 : !torch.list<bool>, !torch.int -> !torch.bool\n"
" %25 = torch.prim.If %24 -> (!torch.list<int>) {\n"
" %26 = func.call @__torch__.torch.jit._shape_functions._reduce_along_dim(%arg9, %23, %false) : (!torch.list<int>, !torch.int, !torch.bool) -> !torch.list<int>\n"
" torch.prim.If.yield %26 : !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %arg9 : !torch.list<int>\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%25 : !torch.list<int>)\n"
" } : (!torch.int, !torch.bool, !torch.list<int>) -> !torch.list<int>\n"
" return %22 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>, %arg7: !torch.bool) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<list<int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>, %arg7: !torch.bool) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n" " %int-1 = torch.constant.int -1\n"
" %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n" " %0 = torch.aten.__getitem__.t %arg2, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
@ -15294,6 +15400,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n" " return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._trilinear\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.list<int>, %arg7: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list<tuple<int, int>>, %arg1: !torch.int) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.cat\"(%arg0: !torch.list<tuple<int, int>>, %arg1: !torch.int) -> !torch.int {\n"
" %true = torch.constant.bool true\n" " %true = torch.constant.bool true\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"

View File

@ -9,6 +9,7 @@
#include "PassDetail.h" #include "PassDetail.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinDialect.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@ -399,9 +400,9 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
auto inputType = cast<ValueTensorType>(input.getType()); auto inputType = cast<ValueTensorType>(input.getType());
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength + auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
reduceDimsLength; reduceDimsLength;
SmallVector<Value> inputShapeTensor; SmallVector<OpFoldResult> inputShapeTensor;
for (auto i = 0; i < inputRank; ++i) { for (auto i = 0; i < inputRank; ++i) {
inputShapeTensor.emplace_back(rewriter.create<AtenSizeIntOp>( inputShapeTensor.emplace_back(rewriter.createOrFold<AtenSizeIntOp>(
loc, input, loc, input,
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(i)))); rewriter.getI64IntegerAttr(i))));
@ -412,13 +413,23 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1)); rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto dimOffset = 0; auto dimOffset = 0;
auto appendDims = [&](int64_t dimLength) { auto materializeIntFold = [&](OpFoldResult thing) {
Value prod = constOne; if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
for (auto i = 0; i < dimLength; ++i) { Value result = rewriter.create<Torch::ConstantIntOp>(
prod = rewriter.create<AtenMulIntOp>(loc, prod, loc, cast<mlir::IntegerAttr>(attr));
inputShapeTensor[i + dimOffset]); return result;
} }
outShapeTensor.emplace_back(prod); return cast<mlir::Value>(thing);
};
auto appendDims = [&](int64_t dimLength) {
OpFoldResult prod = getAsOpFoldResult(constOne);
for (auto i = 0; i < dimLength; ++i) {
prod = rewriter.createOrFold<AtenMulIntOp>(
loc, materializeIntFold(prod),
materializeIntFold(inputShapeTensor[i + dimOffset]));
}
outShapeTensor.emplace_back(materializeIntFold(prod));
dimOffset += dimLength; dimOffset += dimLength;
}; };
@ -570,21 +581,32 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype() Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
: rhsType.getOptionalDtype(); : rhsType.getOptionalDtype();
auto materializeIntFold = [&](OpFoldResult thing) {
if (auto attr = dyn_cast<mlir::Attribute>(thing)) {
Value result = rewriter.create<Torch::ConstantIntOp>(
loc, cast<mlir::IntegerAttr>(attr));
return result;
}
return cast<mlir::Value>(thing);
};
llvm::SmallDenseMap<char, Value> lhsDimShapeMap; llvm::SmallDenseMap<char, Value> lhsDimShapeMap;
for (size_t idx = 0; idx < lhsTokens.size(); ++idx) { for (size_t idx = 0; idx < lhsTokens.size(); ++idx) {
char d = lhsTokens[idx]; char d = lhsTokens[idx];
lhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>( OpFoldResult lhsFold = rewriter.createOrFold<AtenSizeIntOp>(
loc, lhs, loc, lhs,
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(idx))); rewriter.getI64IntegerAttr(idx)));
lhsDimShapeMap[d] = materializeIntFold(lhsFold);
} }
llvm::SmallDenseMap<char, Value> rhsDimShapeMap; llvm::SmallDenseMap<char, Value> rhsDimShapeMap;
for (size_t idx = 0; idx < rhsTokens.size(); ++idx) { for (size_t idx = 0; idx < rhsTokens.size(); ++idx) {
char d = rhsTokens[idx]; char d = rhsTokens[idx];
rhsDimShapeMap[d] = rewriter.create<AtenSizeIntOp>( OpFoldResult rhsFold = rewriter.createOrFold<AtenSizeIntOp>(
loc, rhs, loc, rhs,
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.create<Torch::ConstantIntOp>(loc,
rewriter.getI64IntegerAttr(idx))); rewriter.getI64IntegerAttr(idx)));
rhsDimShapeMap[d] = materializeIntFold(rhsFold);
} }
// parse batch, contracting, other, reduce dims of lhs and rhs // parse batch, contracting, other, reduce dims of lhs and rhs
@ -604,8 +626,9 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
bool lhsContains = lhsDimShapeMap.count(d) > 0; bool lhsContains = lhsDimShapeMap.count(d) > 0;
bool rhsContains = rhsDimShapeMap.count(d) > 0; bool rhsContains = rhsDimShapeMap.count(d) > 0;
if (lhsContains && rhsContains) { if (lhsContains && rhsContains) {
outDimShapeMap[d] = rewriter.create<Torch::PrimMaxIntOp>( OpFoldResult out = rewriter.createOrFold<Torch::PrimMaxIntOp>(
loc, lhsDimShapeMap[d], rhsDimShapeMap[d]); loc, lhsDimShapeMap[d], rhsDimShapeMap[d]);
outDimShapeMap[d] = materializeIntFold(out);
} else if (lhsContains) { } else if (lhsContains) {
outDimShapeMap[d] = lhsDimShapeMap[d]; outDimShapeMap[d] = lhsDimShapeMap[d];
} else if (rhsContains) { } else if (rhsContains) {
@ -1973,6 +1996,125 @@ public:
}; };
} // namespace } // namespace
namespace {
// Trilinear einstein sum, decomposed to:
// (i1.unsqueeze(expand1) * i2.unsqueeze(expand2) * i3.unsqueeze(expand3))
// .sum(sumdim)
// The unrollDim operand does not impact the output of the operation, so
// it is ignored.
class DecomposeAten_TrilinearOp : public OpRewritePattern<Aten_TrilinearOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_TrilinearOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input1 = op.getI1();
Value input2 = op.getI2();
Value input3 = op.getI3();
// Expansions
SmallVector<int64_t> expand1;
SmallVector<int64_t> expand2;
SmallVector<int64_t> expand3;
if (!matchPattern(op.getExpand1(), m_TorchListOfConstantInts(expand1))) {
return rewriter.notifyMatchFailure(op, "expand1 should be constant");
}
if (!matchPattern(op.getExpand2(), m_TorchListOfConstantInts(expand2))) {
return rewriter.notifyMatchFailure(op, "expand2 should be constant");
}
if (!matchPattern(op.getExpand3(), m_TorchListOfConstantInts(expand3))) {
return rewriter.notifyMatchFailure(op, "expand3 should be constant");
}
SmallVector<int64_t> sumDim;
if (!matchPattern(op.getSumdim(), m_TorchListOfConstantInts(sumDim))) {
return rewriter.notifyMatchFailure(op, "sumDim should be constant");
}
// Check if there are any dimensions that intersect between expand1,
// expand2, and expand3.
int64_t totalDims =
cast<BaseTensorType>(input1.getType()).getSizes().size() +
expand1.size();
if (sharedExpandDims(totalDims, expand1, expand2, expand3, sumDim)) {
// pytorch issue filed: https://github.com/pytorch/pytorch/issues/138353
// TODO: Remove warning when issue gets resolved.
op->emitWarning("aten::_trilinear implementation in this case is "
"non-functional (returns an empty dimension). We will "
"intentionally deviate from this behavior.");
}
// Apply unsqueeze to respective input tensors at the specified dimensions
SmallVector<int64_t> sortedExpand1 = expand1;
std::sort(sortedExpand1.begin(), sortedExpand1.end());
for (auto expand : sortedExpand1) {
Value expandDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(expand));
input1 = *unsqueezeTensor(rewriter, op, input1, expandDim);
}
SmallVector<int64_t> sortedExpand2 = expand2;
std::sort(sortedExpand2.begin(), sortedExpand2.end());
for (auto expand : sortedExpand2) {
Value expandDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(expand));
input2 = *unsqueezeTensor(rewriter, op, input2, expandDim);
}
SmallVector<int64_t> sortedExpand3 = expand3;
std::sort(sortedExpand3.begin(), sortedExpand3.end());
for (auto expand : sortedExpand3) {
Value expandDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(expand));
input3 = *unsqueezeTensor(rewriter, op, input3, expandDim);
}
// Apply multiplication operation.
auto mul1 =
rewriter.create<AtenMulTensorOp>(loc, op.getType(), input1, input2);
auto mul2 =
rewriter.create<AtenMulTensorOp>(loc, op.getType(), mul1, input3);
// Apply sum operation.
// Parse sumDim in descending order to avoid any issues with the
// dimensions being removed.
Value result = mul2;
SmallVector<int64_t> sortedSumDims = sumDim;
std::sort(sortedSumDims.rbegin(), sortedSumDims.rend());
for (int64_t dim : sortedSumDims) {
Value dimValue = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(dim));
result =
createSumAlongDimension(rewriter, loc, op, result, dimValue, false);
}
rewriter.replaceOp(op, result);
return success();
}
private:
// Determine if there are any dimensions that intersect between expand1,
// expand2, and expand3.
bool sharedExpandDims(const int64_t &totalDims,
const SmallVector<int64_t> &expand1,
const SmallVector<int64_t> &expand2,
const SmallVector<int64_t> &expand3,
const SmallVector<int64_t> &sumDim) const {
for (int64_t i = 0; i < totalDims; ++i) {
if (!contains(sumDim, i) && contains(expand1, i) &&
contains(expand2, i) && contains(expand3, i)) {
return true;
}
}
return false;
}
bool contains(const SmallVector<int64_t> &vec, int64_t value) const {
return std::find(vec.begin(), vec.end(), value) != vec.end();
}
};
} // namespace
namespace { namespace {
// Calculate the trace of the input tensor as the sum over its diagonal // Calculate the trace of the input tensor as the sum over its diagonal
// elements. This computation is performed as: // elements. This computation is performed as:
@ -10078,6 +10220,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast1dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenAtleast2dOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_TrilinearOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenHardswishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenSoftplusOp>(patterns);

View File

@ -400,6 +400,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenAtleast1dOp>(); target.addIllegalOp<AtenAtleast1dOp>();
target.addIllegalOp<AtenAtleast2dOp>(); target.addIllegalOp<AtenAtleast2dOp>();
target.addIllegalOp<AtenEinsumOp>(); target.addIllegalOp<AtenEinsumOp>();
target.addIllegalOp<Aten_TrilinearOp>();
target.addIllegalOp<AtenTraceOp>(); target.addIllegalOp<AtenTraceOp>();
target.addIllegalOp<AtenAddmmOp>(); target.addIllegalOp<AtenAddmmOp>();
target.addIllegalOp<AtenMeanOp>(); target.addIllegalOp<AtenMeanOp>();

View File

@ -29,6 +29,10 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"DeformConv2D_basic", "DeformConv2D_basic",
"ReduceAnyDimFloatModule_basic", "ReduceAnyDimFloatModule_basic",
"UnfoldModule_basic", "UnfoldModule_basic",
# _trilinear is an implementation of einsum, but sets dimensions to zero
# if a dimension is specified in all expand lists, and not in sumdim list.
# This is a bug in the implementation of _trilinear in PyTorch.
"Aten_TrilinearModuleZerodDimBug_basic",
} }
if torch_version_for_comparison() < version.parse("2.5.0.dev"): if torch_version_for_comparison() < version.parse("2.5.0.dev"):
@ -394,6 +398,8 @@ FX_IMPORTER_XFAIL_SET = {
"AtenIntBoolOpModule_basic", "AtenIntBoolOpModule_basic",
"AtenIntMM_basic", "AtenIntMM_basic",
"AtenItemFpOpModule_basic", "AtenItemFpOpModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"QuantizedReluInt32_basic", "QuantizedReluInt32_basic",
"QuantizedReluInt8_basic", "QuantizedReluInt8_basic",
"QuantizedReluUint8_basic", "QuantizedReluUint8_basic",
@ -532,6 +538,9 @@ FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
"_SoftmaxModule_basic", "_SoftmaxModule_basic",
"UpSampleNearest2dDynamicFactor_basic", "UpSampleNearest2dDynamicFactor_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleSumdims_basic",
# torch export: RuntimeError: cannot mutate tensors with frozen storage # torch export: RuntimeError: cannot mutate tensors with frozen storage
"ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic",
@ -645,6 +654,8 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"AtenTopKModule_basic", "AtenTopKModule_basic",
"AtenTopKSmallestModule_basic", "AtenTopKSmallestModule_basic",
"Aten_EmbeddingBagExample_basic", "Aten_EmbeddingBagExample_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AvgPool2dDivisorOverrideModule_basic", "AvgPool2dDivisorOverrideModule_basic",
"BernoulliTensorModule_basic", "BernoulliTensorModule_basic",
"BincountMinlengthModule_basic", "BincountMinlengthModule_basic",
@ -928,11 +939,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"AtenItemIntOpModule_basic", "AtenItemIntOpModule_basic",
"CrossEntropyLossModule_basic", "CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic", "CrossEntropyLossNoReductionModule_basic",
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"ElementwiseExpm1IntModule_basic", "ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic", "ElementwiseExpm1Module_basic",
"InterpolateDynamicModule_sizes_nearest", "InterpolateDynamicModule_sizes_nearest",
@ -984,6 +990,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
# materialization callback produced value of incorrect type failed # materialization callback produced value of incorrect type failed
"ReduceMaxAlongDimUnsignedInt_basic", "ReduceMaxAlongDimUnsignedInt_basic",
"ReduceMinAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
# torch export: RuntimeError: cannot mutate tensors with frozen storage # torch export: RuntimeError: cannot mutate tensors with frozen storage
"ElementwiseRreluWithNoiseTrainModule_basic", "ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseRreluWithNoiseTrainStaticModule_basic",
@ -3275,6 +3284,12 @@ ONNX_XFAIL_SET = {
"Unfold_Module_Rank_Zero_Size_Zero_basic", "Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Dynamic_basic", "Unfold_Module_Dynamic_basic",
"ViewDtypeStaticModule_basic", "ViewDtypeStaticModule_basic",
"Aten_TrilinearModule_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
} }
if torch_version_for_comparison() < version.parse("2.3.0.dev"): if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@ -4055,6 +4070,12 @@ ONNX_TOSA_XFAIL_SET = {
"AtenSubFloatModule_basic", "AtenSubFloatModule_basic",
"AtenTopKModule_basic", "AtenTopKModule_basic",
"AtenTopKSmallestModule_basic", "AtenTopKSmallestModule_basic",
"Aten_TrilinearModule_basic",
"Aten_TrilinearModuleSumdims_basic",
"Aten_TrilinearModuleSumAllDims_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AtenTrilModule_basic", "AtenTrilModule_basic",
"AtenTrilWithNegDiagonalModule_basic", "AtenTrilWithNegDiagonalModule_basic",
"AtenTrilWithPosDiagonalModule_basic", "AtenTrilWithPosDiagonalModule_basic",

View File

@ -1295,6 +1295,44 @@ def atenunflattenint〡shape(self: List[int], dim: int, sizes: List[int])
def atenlinear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]: def atenlinear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
return upstream_shape_functions.linear(input, weight, bias) return upstream_shape_functions.linear(input, weight, bias)
@check_shape_function([
Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [], [], [], [], 0), # Basic case
Invocation(TensorOfShape(4, 5, 6), TensorOfShape(4, 5, 6), TensorOfShape(4, 5, 6), [1], [0], [0], [], 2), # Expansions w/ Non-Zero unroll_dim
Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [1, 2], [1, 2], [1, 2], [1, 2], 0), # Multiple expansions
Invocation(TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), TensorOfShape(3, 3, 3), [1, 2], [2, 1], [1, 2], [1, 2], 0), # Unordered expansion
ErrorInvocation(TensorOfShape(4, 5, 1), TensorOfShape(4, 5, 3), TensorOfShape(1, 5, 3), [], [], [0], [2], 0), # Num dimensions don't match
])
def aten_trilinear〡shape(i1: List[int], i2: List[int], i3: List[int], expand1: List[int], expand2: List[int], expand3: List[int], sumdim: List[int], unroll_dim: int = 1) -> List[int]:
total_dims = len(i1) + len(expand1)
assert unroll_dim >= 0 and unroll_dim < total_dims, f"unroll_dim must be in [0, {total_dims - 1}]"
i1_copy = upstream_shape_functions._copy(i1)
i2_copy = upstream_shape_functions._copy(i2)
i3_copy = upstream_shape_functions._copy(i3)
# Expand dimensions based on args
inputs = [i1_copy, i2_copy, i3_copy]
expands = [expand1, expand2, expand3]
for index, expand in enumerate(expands):
size = len(inputs[index])
for dim in expand:
assert dim <= size, f"expand dimension {dim} is out of bounds for input of shape {inputs[index]}"
inputs[index].insert(dim, 1)
assert len(i1_copy) == len(i2_copy) == len(i3_copy), "number of dimensions must match"
output_shape = upstream_shape_functions.broadcast_three(i1_copy, i2_copy, i3_copy)
sumdim_bools = [False] * len(output_shape)
for dim in sumdim:
sumdim_bools[dim] = True
for i in range(len(output_shape) - 1, -1, -1):
if sumdim_bools[i]:
output_shape = upstream_shape_functions._reduce_along_dim(output_shape, i, False)
return output_shape
@check_shape_function([ @check_shape_function([
Invocation(TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Same shape Invocation(TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4), TensorOfShape(3, 2, 8, 4)), # Same shape
Invocation(TensorOfShape(3, 2, 16, 8), TensorOfShape(3, 2, 8, 8), TensorOfShape(3, 2, 8, 4)), # Different shape Invocation(TensorOfShape(3, 2, 16, 8), TensorOfShape(3, 2, 8, 8), TensorOfShape(3, 2, 8, 4)), # Different shape
@ -5388,6 +5426,21 @@ def atenlinear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype:
promoted_dtype = promote_dtypes(ranks, dtypes) promoted_dtype = promote_dtypes(ranks, dtypes)
return promoted_dtype return promoted_dtype
@check_dtype_function(
_check_tensors_with_the_same_dtype(3, None, None, None, expand1 = [], expand2 = [], expand3 = [], sumdim = [], unroll_dim = 0),
)
def aten_trilinear〡dtype(i1_rank_dtype: Tuple[int, int], i2_rank_dtype: Tuple[int, int], i3_rank_dtype: Tuple[int, int], expand1: List[int], expand2: List[int], expand3: List[int], sumdim: List[int], unroll_dim: int = 1) -> int:
i1_rank, i1_dtype = i1_rank_dtype
i2_rank, i2_dtype = i2_rank_dtype
i3_rank, i3_dtype = i3_rank_dtype
ranks: List[Optional[int]] = [i1_rank, i2_rank, i3_rank]
dtypes = [i1_dtype, i2_dtype, i3_dtype]
return promote_dtypes(
ranks,
dtypes,
)
@check_dtype_function( @check_dtype_function(
[Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), [Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]),
Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]),

View File

@ -1022,6 +1022,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)" "aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?, bool) -> (Tensor)"
) )
emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)") emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)")
emit(
"aten::_trilinear : (Tensor, Tensor, Tensor, int[], int[], int[], int[], int) -> (Tensor)"
)
# Dict ops. # Dict ops.
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True) emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)

View File

@ -1674,6 +1674,9 @@ def Rot90NegativeEvenRotationsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 5, 1, 7, 3)) module.forward(tu.rand(6, 5, 1, 7, 3))
# ==============================================================================
class Unfold_Module(torch.nn.Module): class Unfold_Module(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -1772,3 +1775,173 @@ class Unfold_Module_Dynamic(torch.nn.Module):
@register_test_case(module_factory=lambda: Unfold_Module_Dynamic()) @register_test_case(module_factory=lambda: Unfold_Module_Dynamic())
def Unfold_Module_Dynamic_basic(module, tu: TestUtils): def Unfold_Module_Dynamic_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4, 4, 4)) module.forward(tu.rand(6, 4, 4, 4))
# ==============================================================================
class Aten_TrilinearModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([3, 3, 3], torch.float32, True),
([3, 3, 3], torch.float32, True),
([3, 3, 3], torch.float32, True),
]
)
def forward(self, i1, i2, i3):
return torch.ops.aten._trilinear(
i1, i2, i3, expand1=[], expand2=[], expand3=[], sumdim=[], unroll_dim=0
)
@register_test_case(module_factory=lambda: Aten_TrilinearModule())
def Aten_TrilinearModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3, 3), tu.rand(3, 3, 3), tu.rand(3, 3, 3))
class Aten_TrilinearModuleSumdims(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 6], torch.float32, True),
([2, 3, 6], torch.float32, True),
([2, 3, 6], torch.float32, True),
]
)
def forward(self, i1, i2, i3):
return torch.ops.aten._trilinear(
i1, i2, i3, expand1=[1], expand2=[], expand3=[], sumdim=[0, 2], unroll_dim=0
)
@register_test_case(module_factory=lambda: Aten_TrilinearModuleSumdims())
def Aten_TrilinearModuleSumdims_basic(module, tu: TestUtils):
return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6))
class Aten_TrilinearModuleSumAllDims(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 6], torch.float32, True),
([2, 3, 6], torch.float32, True),
([2, 3, 6], torch.float32, True),
]
)
def forward(self, i1, i2, i3):
return torch.ops.aten._trilinear(
i1,
i2,
i3,
expand1=[1],
expand2=[],
expand3=[],
sumdim=[0, 1, 2],
unroll_dim=0,
)
@register_test_case(module_factory=lambda: Aten_TrilinearModuleSumAllDims())
def Aten_TrilinearModuleSumAllDims_basic(module, tu: TestUtils):
return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6))
class Aten_TrilinearModuleVaryingRanks(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 6], torch.float32, True),
([2, 3, 6], torch.float32, True),
([6], torch.float32, True),
]
)
def forward(self, i1, i2, i3):
return torch.ops.aten._trilinear(
i1,
i2,
i3,
expand1=[1],
expand2=[],
expand3=[0, 1],
sumdim=[0],
unroll_dim=0,
)
@register_test_case(module_factory=lambda: Aten_TrilinearModuleVaryingRanks())
def Aten_TrilinearModuleVaryingRanks_basic(module, tu: TestUtils):
return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(6))
class Aten_TrilinearModuleVaryingRanksUnorderedExpands(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 6], torch.float32, True),
([2, 3, 6], torch.float32, True),
([6], torch.float32, True),
]
)
def forward(self, i1, i2, i3):
return torch.ops.aten._trilinear(
i1,
i2,
i3,
expand1=[1],
expand2=[],
expand3=[1, 0],
sumdim=[2, 0],
unroll_dim=0,
)
@register_test_case(
module_factory=lambda: Aten_TrilinearModuleVaryingRanksUnorderedExpands()
)
def Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic(module, tu: TestUtils):
return module.forward(tu.rand(2, 6), tu.rand(2, 3, 6), tu.rand(6))
class Aten_TrilinearModuleZerodDimBug(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 3, 6], torch.float32, True),
([2, 3, 6], torch.float32, True),
([2, 3, 6], torch.float32, True),
]
)
def forward(self, i1, i2, i3):
return torch.ops.aten._trilinear(
i1, i2, i3, expand1=[0], expand2=[0], expand3=[0], sumdim=[2], unroll_dim=0
)
@register_test_case(module_factory=lambda: Aten_TrilinearModuleZerodDimBug())
def Aten_TrilinearModuleZerodDimBug_basic(module, tu: TestUtils):
return module.forward(tu.rand(2, 3, 6), tu.rand(2, 3, 6), tu.rand(2, 3, 6))