mirror of https://github.com/llvm/torch-mlir
Implement lowering of torch.aten.linalg_cross (#2986)
Closes [nod-ai/SHARK-Turbine#497](https://github.com/nod-ai/SHARK-Turbine/issues/497)pull/3026/head
parent
6fa21bd8b1
commit
524ff99216
|
@ -11732,6 +11732,32 @@ def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenLinalgCrossOp : Torch_Op<"aten.linalg_cross", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$other,
|
||||||
|
Torch_IntType:$dim
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenLinalgCrossOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||||
|
}
|
||||||
|
void AtenLinalgCrossOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 3, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
let hasVerifier = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -4278,6 +4278,96 @@ LogicalResult AtenPermuteOp::verify() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// AtenLinalgCrossOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
LogicalResult AtenLinalgCrossOp::verify() {
|
||||||
|
|
||||||
|
auto selfType = getSelf().getType().cast<BaseTensorType>();
|
||||||
|
auto otherType = getOther().getType().cast<BaseTensorType>();
|
||||||
|
|
||||||
|
if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() ||
|
||||||
|
!otherType.hasSizes()) {
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
Type selfDtype = selfType.getDtype();
|
||||||
|
Type otherDtype = otherType.getDtype();
|
||||||
|
|
||||||
|
// the operation succeeds only if both inputs have the same dtype
|
||||||
|
if (selfDtype != otherDtype) {
|
||||||
|
return emitOpError("input tensors must have the same dtype, but got ")
|
||||||
|
<< selfDtype << " and " << otherDtype;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if any of the input tensors has torch.bool dtype.
|
||||||
|
// The operation does not support this type.
|
||||||
|
// The docs state that only float, double, cfloat and cdouble dtypes are
|
||||||
|
// supported, but, when testing, it fails only for boolean dtype. Update to
|
||||||
|
// fit the docs if necessary.
|
||||||
|
// https://pytorch.org/docs/stable/generated/torch.linalg.cross.html
|
||||||
|
if (selfDtype.isSignlessInteger(1) || otherDtype.isSignlessInteger(1)) {
|
||||||
|
return emitOpError("input tensors must not have bool dtype");
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<int64_t> selfShape = selfType.getSizes();
|
||||||
|
ArrayRef<int64_t> otherShape = otherType.getSizes();
|
||||||
|
|
||||||
|
int64_t selfRank = selfShape.size();
|
||||||
|
int64_t otherRank = otherShape.size();
|
||||||
|
|
||||||
|
// check if both input tensors have the same number of dims
|
||||||
|
if (selfRank != otherRank) {
|
||||||
|
return emitOpError("input tensors must have the same number of dimensions, "
|
||||||
|
"but got ")
|
||||||
|
<< selfRank << " and " << otherRank;
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert dim to an integer type
|
||||||
|
int64_t dim;
|
||||||
|
if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) {
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if dim is in the correct range
|
||||||
|
if (dim >= selfRank || dim < -selfRank) {
|
||||||
|
return emitOpError("dim expected to be in rank of [")
|
||||||
|
<< -selfRank << ", " << selfRank - 1 << "], but got " << dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
// compensate for possible negative dim value
|
||||||
|
if (dim < 0) {
|
||||||
|
dim += selfRank;
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the size of the dimensions specified by 'dim' is equal to 3
|
||||||
|
// (required by the operation)
|
||||||
|
if ((selfShape[dim] != 3 && selfShape[dim] != kUnknownSize) ||
|
||||||
|
(otherShape[dim] != 3 && otherShape[dim] != kUnknownSize)) {
|
||||||
|
return emitOpError("inputs dimension ")
|
||||||
|
<< dim << " must have length 3, but got " << selfShape[dim]
|
||||||
|
<< " and " << otherShape[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if there is a disparity between dimension sizes.
|
||||||
|
// Dimensions at the same index must either have the same size,
|
||||||
|
// or one of them must be equal to 1.
|
||||||
|
int32_t i = 0;
|
||||||
|
for (auto [selfCurrent, otherCurrent] :
|
||||||
|
llvm::zip_equal(selfShape, otherShape)) {
|
||||||
|
if (selfCurrent != otherCurrent && selfCurrent != 1 && otherCurrent != 1) {
|
||||||
|
return emitOpError("the size of first tensor (")
|
||||||
|
<< selfCurrent << ") must match the size of second tensor ("
|
||||||
|
<< otherCurrent << ") at dimension " << i
|
||||||
|
<< " or one of them must be 1";
|
||||||
|
}
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// DtypeCalculateYieldDtypesOp
|
// DtypeCalculateYieldDtypesOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -6793,6 +6793,57 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.linalg_cross\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int) -> !torch.list<int> {\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %str_0 = torch.constant.str \"the size of first tensor ({}) must match the size of second tensor ({}) at dimension {}\"\n"
|
||||||
|
" %true = torch.constant.bool true\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str_1 = torch.constant.str \"AssertionError: inputs must have the same number of dimensions\"\n"
|
||||||
|
" %int1 = torch.constant.int 1\n"
|
||||||
|
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||||
|
" %1 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
|
||||||
|
" %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %2 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||||
|
" torch.prim.Loop %3, %true, init() {\n"
|
||||||
|
" ^bb0(%arg3: !torch.int):\n"
|
||||||
|
" %5 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
|
||||||
|
" torch.prim.If.yield %true : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" %9 = torch.prim.If %8 -> (!torch.bool) {\n"
|
||||||
|
" torch.prim.If.yield %true : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %10 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %11 : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If %9 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %12 = torch.aten.format(%str_0, %10, %11, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n"
|
||||||
|
" %13 = torch.aten.add.str %str, %12 : !torch.str, !torch.str -> !torch.str\n"
|
||||||
|
" torch.prim.RaiseException %13, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.Loop.condition %true, iter()\n"
|
||||||
|
" } : (!torch.int, !torch.bool) -> ()\n"
|
||||||
|
" %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %4 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -10033,6 +10084,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_cross\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int) -> !torch.int {\n"
|
||||||
|
" %int11 = torch.constant.int 11\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
|
||||||
|
" %3 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %3 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %4 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
|
||||||
|
" torch.prim.If %4 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %5 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||||
|
" return %6 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||||
" return %arg3 : !torch.int\n"
|
" return %arg3 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
|
|
@ -1823,6 +1823,117 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select,
|
||||||
|
// aten.add.Tensor and aten.mull.Tensor. See
|
||||||
|
// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70.
|
||||||
|
// def linalg_cross(self: Tensor, other: Tensor, dim: int = -1):
|
||||||
|
// broadcast_shape = compute_broadcast_shape(self, other)
|
||||||
|
// a = torch.broadcast_to(self, broadcast_shape)
|
||||||
|
// b = torch.broadcast_to(other, broadcast_shape)
|
||||||
|
// idx = torch.arange(3)
|
||||||
|
// return a.index_select(dim, (idx + 1) % 3) *
|
||||||
|
// b.index_select(dim, (idx + 2) % 3) -
|
||||||
|
// a.index_select(dim, (idx + 2) % 3) *
|
||||||
|
// b.index_select(dim, (idx + 1) % 3)
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenLinalgCrossOp : public OpRewritePattern<AtenLinalgCrossOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenLinalgCrossOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value self = op.getSelf();
|
||||||
|
Value other = op.getOther();
|
||||||
|
Type opType = op.getType();
|
||||||
|
Value dim = op.getDim();
|
||||||
|
|
||||||
|
auto resType = self.getType().cast<BaseTensorType>();
|
||||||
|
if (!resType.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
|
}
|
||||||
|
|
||||||
|
Type dtype = resType.getDtype();
|
||||||
|
if (dtype.isa<mlir::ComplexType>()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "lowering of aten.linalg_cross for complex inputs dtype is "
|
||||||
|
"currently unimplemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculate common shape for broadcast
|
||||||
|
SmallVector<int64_t> broadcastShape;
|
||||||
|
SmallVector<Value> broadcastShapeValue;
|
||||||
|
computeBroadcastShape(rewriter, loc, self, other, broadcastShape,
|
||||||
|
broadcastShapeValue);
|
||||||
|
|
||||||
|
Type broadcastType = ValueTensorType::get(
|
||||||
|
op.getContext(), llvm::ArrayRef(broadcastShape), dtype);
|
||||||
|
|
||||||
|
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
|
||||||
|
broadcastShapeValue);
|
||||||
|
|
||||||
|
// broadcast tensors to common shape
|
||||||
|
auto a = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, self,
|
||||||
|
indexBroadcastShapeTorchList);
|
||||||
|
auto b = rewriter.create<AtenBroadcastToOp>(loc, broadcastType, other,
|
||||||
|
indexBroadcastShapeTorchList);
|
||||||
|
|
||||||
|
// create constants
|
||||||
|
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
Value constTwo = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(2));
|
||||||
|
Value constThree = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(3));
|
||||||
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
|
||||||
|
// idx = torch.arange(3)
|
||||||
|
auto outType = opType.dyn_cast<BaseTensorType>();
|
||||||
|
auto arangeType = outType.getWithSizesAndDtype(
|
||||||
|
llvm::ArrayRef<int64_t>(3),
|
||||||
|
IntegerType::get(op.getContext(), 64, IntegerType::Signed));
|
||||||
|
auto idx = rewriter.create<AtenArangeOp>(
|
||||||
|
loc, arangeType, constThree, /*dtype=*/none, /*layout=*/none,
|
||||||
|
/*device=*/none, /*pin_memory=*/none);
|
||||||
|
|
||||||
|
// (idx + 1) and (idx + 2)
|
||||||
|
auto idxPlusOne = rewriter.create<AtenAddScalarOp>(loc, arangeType, idx,
|
||||||
|
constOne, constOne);
|
||||||
|
auto idxPlusTwo = rewriter.create<AtenAddScalarOp>(loc, arangeType, idx,
|
||||||
|
constTwo, constOne);
|
||||||
|
|
||||||
|
// (idx + 1) % 3 and (idx + 2) % 3
|
||||||
|
auto idxPlusOneRemainderThree = rewriter.create<AtenRemainderScalarOp>(
|
||||||
|
loc, arangeType, idxPlusOne, constThree);
|
||||||
|
auto idxPlusTwoRemainderThree = rewriter.create<AtenRemainderScalarOp>(
|
||||||
|
loc, arangeType, idxPlusTwo, constThree);
|
||||||
|
|
||||||
|
// a.index_select(dim, (idx + 1) % 3) * b.index_select(dim, (idx + 2) % 3)
|
||||||
|
auto idxSelectAPlusOne = rewriter.create<AtenIndexSelectOp>(
|
||||||
|
loc, opType, a, dim, idxPlusOneRemainderThree);
|
||||||
|
auto idxSelectBPlusTwo = rewriter.create<AtenIndexSelectOp>(
|
||||||
|
loc, opType, b, dim, idxPlusTwoRemainderThree);
|
||||||
|
auto firstMul = rewriter.create<AtenMulTensorOp>(
|
||||||
|
loc, opType, idxSelectAPlusOne, idxSelectBPlusTwo);
|
||||||
|
|
||||||
|
// a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3)
|
||||||
|
auto idxSelectAPlusTwo = rewriter.create<AtenIndexSelectOp>(
|
||||||
|
loc, opType, a, dim, idxPlusTwoRemainderThree);
|
||||||
|
auto idxSelectBPlusOne = rewriter.create<AtenIndexSelectOp>(
|
||||||
|
loc, opType, b, dim, idxPlusOneRemainderThree);
|
||||||
|
auto secondMul = rewriter.create<AtenMulTensorOp>(
|
||||||
|
loc, opType, idxSelectAPlusTwo, idxSelectBPlusOne);
|
||||||
|
|
||||||
|
// subtract the results of the two multiplications from above
|
||||||
|
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, opType, firstMul,
|
||||||
|
secondMul, constOne);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
|
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
|
||||||
// prims.collapse operations.
|
// prims.collapse operations.
|
||||||
//
|
//
|
||||||
|
@ -7081,6 +7192,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenSelectIntOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMatmulOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenMvOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
|
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
|
||||||
|
|
|
@ -395,6 +395,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<AtenNormScalarOptDimOp>();
|
target.addIllegalOp<AtenNormScalarOptDimOp>();
|
||||||
target.addIllegalOp<AtenSelectIntOp>();
|
target.addIllegalOp<AtenSelectIntOp>();
|
||||||
target.addIllegalOp<AtenMvOp>();
|
target.addIllegalOp<AtenMvOp>();
|
||||||
|
target.addIllegalOp<AtenLinalgCrossOp>();
|
||||||
target.addIllegalOp<AtenPixelShuffleOp>();
|
target.addIllegalOp<AtenPixelShuffleOp>();
|
||||||
target.addIllegalOp<AtenTOp>();
|
target.addIllegalOp<AtenTOp>();
|
||||||
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
||||||
|
|
|
@ -2107,6 +2107,9 @@ ONNX_XFAIL_SET = {
|
||||||
"ReduceMinAlongDimUnsignedInt_basic",
|
"ReduceMinAlongDimUnsignedInt_basic",
|
||||||
"TensorsStackNegativeDimModule_basic",
|
"TensorsStackNegativeDimModule_basic",
|
||||||
"TensorsStackPromoteDTypeModule_basic",
|
"TensorsStackPromoteDTypeModule_basic",
|
||||||
|
|
||||||
|
# Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1"
|
||||||
|
"AtenLinalgCrossDynamic_basic"
|
||||||
}
|
}
|
||||||
|
|
||||||
ONNX_CRASHING_SET = { }
|
ONNX_CRASHING_SET = { }
|
||||||
|
|
|
@ -384,6 +384,17 @@ def aten〇clone〡shape(self: List[int], memory_format: Optional[int] = None) -
|
||||||
def aten〇lift_fresh_copy〡shape(self: List[int]) -> List[int]:
|
def aten〇lift_fresh_copy〡shape(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
@check_shape_function([
|
||||||
|
Invocation(TensorOfShape(1, 2, 3), TensorOfShape(4, 1, 3)), # two dimensions to broadcast, self[0] and other[1]
|
||||||
|
ErrorInvocation(TensorOfShape(3), TensorOfShape(2, 3)), # different number of dimensions
|
||||||
|
ErrorInvocation(TensorOfShape(2, 3), TensorOfShape(4, 3)) # non-broadcastable dimensions
|
||||||
|
])
|
||||||
|
def aten〇linalg_cross〡shape(self: List[int], other: List[int], dim: int = -1) -> List[int]:
|
||||||
|
assert len(self) == len(other), "inputs must have the same number of dimensions"
|
||||||
|
for i in range(len(self)):
|
||||||
|
assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}"
|
||||||
|
return upstream_shape_functions.broadcast(self, other)
|
||||||
|
|
||||||
def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
|
def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]:
|
||||||
return upstream_shape_functions.unary(grad_output)
|
return upstream_shape_functions.unary(grad_output)
|
||||||
|
|
||||||
|
@ -2381,6 +2392,19 @@ def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
return self_dtype
|
return self_dtype
|
||||||
|
|
||||||
|
@check_dtype_function(
|
||||||
|
_check_tensors_with_the_same_dtype(tensor_device="cpu", tensor_shapes=[(2,3), (2,3)], error_types={torch.bool}) + # same dtype
|
||||||
|
[ErrorInvocation(TensorOfShape(2, 3, dtype=torch.int32, device="cpu"), TensorOfShape(2, 3, dtype=torch.float16, device="cpu"))] #different dtypes
|
||||||
|
)
|
||||||
|
def aten〇linalg_cross〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], dim: int = -1) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
other_rank, other_dtype = other_rank_dtype
|
||||||
|
ranks: List[Optional[int]] = [self_rank, other_rank]
|
||||||
|
assert self_dtype == other_dtype
|
||||||
|
assert self_dtype != torch.bool
|
||||||
|
dtypes = [self_dtype, other_dtype]
|
||||||
|
return promote_dtypes(ranks, dtypes)
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_two_tensor_op(dim=0, input_dtype=torch.float32) +
|
_check_two_tensor_op(dim=0, input_dtype=torch.float32) +
|
||||||
_check_two_tensor_op(dim=0, input_dtype=torch.float64))
|
_check_two_tensor_op(dim=0, input_dtype=torch.float64))
|
||||||
|
|
|
@ -687,6 +687,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
|
emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)")
|
||||||
emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
|
emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True)
|
||||||
|
|
||||||
# Functionalization ops
|
# Functionalization ops
|
||||||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||||
|
|
|
@ -289,3 +289,114 @@ class AtenMmQuint8(torch.nn.Module):
|
||||||
def AtenMmQuint8_basic(module, tu: TestUtils):
|
def AtenMmQuint8_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8),
|
||||||
tu.randint(4, 3, low=-128, high=127).to(torch.int8))
|
tu.randint(4, 3, low=-128, high=127).to(torch.int8))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenLinalgCrossInt(torch.nn.Module):
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([2, 3], torch.int64, True),
|
||||||
|
([2, 3], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.ops.aten.linalg_cross(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenLinalgCrossInt())
|
||||||
|
def AtenLinalgCrossInt_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.randint(2, 3), tu.randint(2, 3))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenLinalgCrossFloat(torch.nn.Module):
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([2, 3], torch.float32, True),
|
||||||
|
([2, 3], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.ops.aten.linalg_cross(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenLinalgCrossFloat())
|
||||||
|
def AtenLinalgCrossFloat_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(2, 3), tu.rand(2, 3))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenLinalgCrossBroadcast(torch.nn.Module):
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([1, 4, 3], torch.float32, True),
|
||||||
|
([5, 4, 3], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.ops.aten.linalg_cross(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenLinalgCrossBroadcast())
|
||||||
|
def AtenLinalgCrossBroadcast_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1, 4, 3), tu.rand(5, 4, 3))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenLinalgCrossCustomDim(torch.nn.Module):
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([1, 4, 3, 2, 2], torch.float32, True),
|
||||||
|
([5, 4, 3, 2, 1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.ops.aten.linalg_cross(a, b, dim=2)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenLinalgCrossCustomDim())
|
||||||
|
def AtenLinalgCrossCustomDim_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenLinalgCrossNegativeDim(torch.nn.Module):
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([1, 4, 3, 2, 2], torch.float32, True),
|
||||||
|
([5, 4, 3, 2, 1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.ops.aten.linalg_cross(a, b, dim=-3)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenLinalgCrossNegativeDim())
|
||||||
|
def AtenLinalgCrossNegativeDim_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class AtenLinalgCrossDynamic(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, a, b):
|
||||||
|
return torch.ops.aten.linalg_cross(a, b, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
|
||||||
|
def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
|
Loading…
Reference in New Issue