mirror of https://github.com/llvm/torch-mlir
Implement lowering of torch.aten.tril_indices (#3517)
parent
f0ce1e94ce
commit
c7d972ed58
|
@ -15857,6 +15857,36 @@ def Torch_AtenTriuIndicesOp : Torch_Op<"aten.triu_indices", [
|
|||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenTrilIndicesOp : Torch_Op<"aten.tril_indices", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::tril_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
Torch_IntType:$row,
|
||||
Torch_IntType:$col,
|
||||
Torch_IntType:$offset,
|
||||
AnyTorchOptionalIntType:$dtype,
|
||||
AnyTorchOptionalIntType:$layout,
|
||||
AnyTorchOptionalDeviceType:$device,
|
||||
AnyTorchOptionalBoolType:$pin_memory
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenTrilIndicesOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||
}
|
||||
void AtenTrilIndicesOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 7, 1);
|
||||
}
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -5252,3 +5252,40 @@ LogicalResult AtenTriuIndicesOp::verify() {
|
|||
|
||||
return success();
|
||||
}
|
||||
|
||||
// AtenTrilIndicesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AtenTrilIndicesOp::verify() {
|
||||
|
||||
// Check if row, col and offset are constant ints
|
||||
int64_t row;
|
||||
if (!matchPattern(getRow(), m_TorchConstantInt(&row)))
|
||||
return success();
|
||||
|
||||
int64_t col;
|
||||
if (!matchPattern(getCol(), m_TorchConstantInt(&col)))
|
||||
return success();
|
||||
|
||||
int64_t offset;
|
||||
if (!matchPattern(getOffset(), m_TorchConstantInt(&offset)))
|
||||
return success();
|
||||
|
||||
// Check if values of row, and col are valid
|
||||
if (row < 0)
|
||||
return emitOpError("row must be non-negative, got ") << row;
|
||||
|
||||
if (col < 0)
|
||||
return emitOpError("col must be non-negative, got ") << col;
|
||||
|
||||
// Check if dtype is valid
|
||||
int64_t dtype;
|
||||
if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype)))
|
||||
return success();
|
||||
if (dtype != (int)torch_upstream::ScalarType::Int &&
|
||||
dtype != (int)torch_upstream::ScalarType::Long)
|
||||
return emitOpError(
|
||||
"'triu_indices' implemented only for torch.int32 and torch.int64");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -9993,6 +9993,53 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.tril_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int0 = torch.constant.int 0\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.eq.int %arg0, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %3 = torch.aten.eq.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %3 : !torch.bool\n"
|
||||
" }\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.list<int>) {\n"
|
||||
" %3 = torch.prim.ListConstruct %int2, %int0 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %3 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %3 = torch.aten.gt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %4 = torch.prim.If %3 -> (!torch.int) {\n"
|
||||
" %21 = torch.aten.add.int %int1, %arg2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %22 = torch.prim.min.int %arg1, %21 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" torch.prim.If.yield %22 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %21 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %22 = torch.aten.gt.int %21, %int0 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %23 = torch.aten.Int.bool %22 : !torch.bool -> !torch.int\n"
|
||||
" torch.prim.If.yield %23 : !torch.int\n"
|
||||
" }\n"
|
||||
" %5 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %6 = torch.prim.min.int %arg1, %5 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %7 = torch.prim.max.int %int0, %6 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %8 = torch.aten.add.int %arg0, %arg2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %9 = torch.prim.min.int %arg0, %8 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %10 = torch.prim.max.int %int0, %9 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %11 = torch.aten.sub.int %7, %4 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %12 = torch.aten.add.int %11, %int1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %13 = torch.aten.add.int %4, %7 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %14 = torch.aten.mul.int %13, %12 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %15 = torch.aten.floordiv.int %14, %int2 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %16 = torch.aten.sub.int %10, %12 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %17 = torch.aten.mul.int %16, %arg1 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %18 = torch.prim.max.int %int0, %17 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %19 = torch.aten.add.int %15, %18 : !torch.int, !torch.int -> !torch.int\n"
|
||||
" %20 = torch.prim.ListConstruct %int2, %19 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %20 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %2 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>>\n"
|
||||
|
@ -14773,6 +14820,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.tril_indices\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0 = torch.aten.__is__ %arg3, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
|
||||
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
|
||||
" torch.prim.If.yield %int4 : !torch.int\n"
|
||||
" } else {\n"
|
||||
" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional<int> -> !torch.int\n"
|
||||
" torch.prim.If.yield %2 : !torch.int\n"
|
||||
" }\n"
|
||||
" return %1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
|
|
|
@ -992,6 +992,214 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// decomposition of torch.tril_indices
|
||||
// https://github.com/pytorch/pytorch/blob/67ef2683d970fc541b6d266d4b3f8ba9d13844ca/torch/_refs/__init__.py#L5797
|
||||
namespace {
|
||||
class DecomposeAtenTrilIndicesOp : public OpRewritePattern<AtenTrilIndicesOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenTrilIndicesOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op.getContext();
|
||||
|
||||
// Required parameters
|
||||
Value row = op.getRow();
|
||||
Value col = op.getCol();
|
||||
Value offset = op.getOffset();
|
||||
|
||||
// Check if row, col and offset are constant ints
|
||||
int64_t rowInt;
|
||||
if (!matchPattern(row, m_TorchConstantInt(&rowInt)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Unimplemented: row not constant int");
|
||||
|
||||
int64_t colInt;
|
||||
if (!matchPattern(col, m_TorchConstantInt(&colInt)))
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Unimplemented: col not constant int");
|
||||
|
||||
int64_t offsetInt;
|
||||
if (!matchPattern(offset, m_TorchConstantInt(&offsetInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: offset not constant int");
|
||||
|
||||
// Optional parameters
|
||||
Value dtype = op.getDtype();
|
||||
Value layout = op.getLayout();
|
||||
Value device = op.getDevice();
|
||||
Value pinMemory = op.getPinMemory();
|
||||
|
||||
// Constants
|
||||
Value cstZero = rewriter.create<Torch::ConstantIntOp>(loc, 0);
|
||||
Value cstOne = rewriter.create<Torch::ConstantIntOp>(loc, 1);
|
||||
Value cstTwo = rewriter.create<Torch::ConstantIntOp>(loc, 2);
|
||||
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
Value cstZeroPointFive = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(0.5));
|
||||
Value cstTwoFloat = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(2.0));
|
||||
|
||||
// Get int value for dtype
|
||||
int64_t dtypeInt;
|
||||
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: dtype not constant int");
|
||||
|
||||
FailureOr<Type> dtypeType =
|
||||
getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
|
||||
if (failed(dtypeType))
|
||||
return rewriter.notifyMatchFailure(op, "dtype is undefined");
|
||||
|
||||
// Calculte trapezoidSize, rectangleSize and mFirstRow
|
||||
std::tuple<int64_t, int64_t, int64_t> triuSizes =
|
||||
getTrilSizes(rowInt, colInt, offsetInt);
|
||||
|
||||
int64_t trapezoidSizeInt = std::get<0>(triuSizes);
|
||||
int64_t rectangleSizeInt = std::get<1>(triuSizes);
|
||||
int64_t mFirstRowInt = std::get<2>(triuSizes);
|
||||
|
||||
// Create const int Values from ints
|
||||
Value trapezoidSize = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(trapezoidSizeInt));
|
||||
Value rectangleSize = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(rectangleSizeInt));
|
||||
Value mFirstRow = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(mFirstRowInt));
|
||||
|
||||
// Calculte column offset
|
||||
int64_t rowOffsetInt = (-offsetInt > 0) ? (-offsetInt) : 0;
|
||||
Value rowOffset = rewriter.create<Torch::ConstantIntOp>(loc, rowOffsetInt);
|
||||
|
||||
// First we do the indices for TOP trapezoid
|
||||
auto f64DtypeInt =
|
||||
getDtypeIntValueForType(rewriter, loc, rewriter.getF64Type());
|
||||
auto arrangeType =
|
||||
getTensorTypeFromShapeValues({trapezoidSize}, rewriter.getF64Type());
|
||||
Value xs1 =
|
||||
rewriter.create<AtenArangeOp>(loc, arrangeType, trapezoidSize,
|
||||
/*dtype=*/f64DtypeInt, /*layout=*/layout,
|
||||
/*device=*/device,
|
||||
/*pin_memory=*/pinMemory);
|
||||
|
||||
// b = m_first_row - 0.5
|
||||
Value mFirstRowFloat = rewriter.create<Torch::ConstantFloatOp>(
|
||||
loc, rewriter.getF64FloatAttr(mFirstRowInt));
|
||||
Value b =
|
||||
rewriter.create<AtenSubFloatOp>(loc, mFirstRowFloat, cstZeroPointFive);
|
||||
|
||||
// Implements this piece of code: row_inds1 = torch.floor(-b + torch.sqrt(b
|
||||
// * b + 2 * xs1))
|
||||
Value bSquare = rewriter.create<AtenMulFloatOp>(loc, b, b);
|
||||
|
||||
Value twoTimesXs1 =
|
||||
rewriter.create<AtenMulScalarOp>(loc, xs1.getType(), xs1, cstTwoFloat);
|
||||
Value sqrtInput = rewriter.create<AtenAddScalarOp>(
|
||||
loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne);
|
||||
|
||||
Value sqrt =
|
||||
rewriter.create<AtenSqrtOp>(loc, sqrtInput.getType(), sqrtInput);
|
||||
|
||||
Value rowInds1 =
|
||||
rewriter.create<AtenSubScalarOp>(loc, sqrt.getType(), sqrt, b, cstOne);
|
||||
rowInds1 = rewriter.create<AtenFloorOp>(loc, rowInds1.getType(), rowInds1);
|
||||
|
||||
// Implements this piece of code: col_inds1 = torch.floor(xs1 - (2 *
|
||||
// m_first_row - 1 + row_inds1) * row_inds1 * 0.5)
|
||||
Value twoTimesMFirstRow =
|
||||
rewriter.create<AtenMulIntOp>(loc, cstTwo, mFirstRow);
|
||||
twoTimesMFirstRow =
|
||||
rewriter.create<AtenSubIntOp>(loc, twoTimesMFirstRow, cstOne);
|
||||
twoTimesMFirstRow = rewriter.create<AtenAddScalarOp>(
|
||||
loc, rowInds1.getType(), rowInds1, twoTimesMFirstRow, cstOne);
|
||||
twoTimesMFirstRow = rewriter.create<AtenMulTensorOp>(
|
||||
loc, twoTimesMFirstRow.getType(), twoTimesMFirstRow, rowInds1);
|
||||
twoTimesMFirstRow = rewriter.create<AtenMulScalarOp>(
|
||||
loc, twoTimesMFirstRow.getType(), twoTimesMFirstRow, cstZeroPointFive);
|
||||
|
||||
Value colInds1 = rewriter.create<AtenSubTensorOp>(
|
||||
loc, xs1.getType(), xs1, twoTimesMFirstRow, cstOne);
|
||||
colInds1 = rewriter.create<AtenFloorOp>(loc, colInds1.getType(), colInds1);
|
||||
|
||||
// Convert top trapezoid indices to dtype
|
||||
Type int64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true);
|
||||
|
||||
auto rowInds1Type = cast<BaseTensorType>(rowInds1.getType());
|
||||
ArrayRef<int64_t> sizes = rowInds1Type.getSizes();
|
||||
Type finalRowType = rowInds1Type.getWithSizesAndDtype(sizes, int64Type);
|
||||
rowInds1 = rewriter.create<AtenAddScalarOp>(loc, rowInds1.getType(),
|
||||
rowInds1, rowOffset, cstOne);
|
||||
rowInds1 = rewriter.create<AtenToDtypeOp>(
|
||||
loc, finalRowType, rowInds1, dtype,
|
||||
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
||||
/*memory_format=*/cstOne);
|
||||
|
||||
auto colInds1Type = cast<BaseTensorType>(colInds1.getType());
|
||||
sizes = colInds1Type.getSizes();
|
||||
Type finalColType = colInds1Type.getWithSizesAndDtype(sizes, int64Type);
|
||||
colInds1 = rewriter.create<AtenToDtypeOp>(
|
||||
loc, finalColType, colInds1, dtype,
|
||||
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
||||
/*memory_format=*/cstOne);
|
||||
|
||||
// Calculate indices for BOTTOM rectangle
|
||||
arrangeType = getTensorTypeFromShapeValues({rectangleSize}, *dtypeType);
|
||||
Value xs2 =
|
||||
rewriter.create<AtenArangeOp>(loc, arrangeType, rectangleSize,
|
||||
/*dtype=*/dtype, /*layout=*/layout,
|
||||
/*device=*/device,
|
||||
/*pin_memory=*/pinMemory);
|
||||
|
||||
// Implements this line of code: row_inds2 = xs2 // col + (col - m_first_row
|
||||
// + 1 + row_offset)
|
||||
Value rowInds2 =
|
||||
rewriter.create<AtenFloorDivideScalarOp>(loc, xs2.getType(), xs2, col);
|
||||
int64_t addInt = colInt - mFirstRowInt + 1 + rowOffsetInt;
|
||||
Value cstAdd = rewriter.create<Torch::ConstantIntOp>(loc, addInt);
|
||||
rowInds2 = rewriter.create<AtenAddScalarOp>(loc, rowInds2.getType(),
|
||||
rowInds2, cstAdd, cstOne);
|
||||
|
||||
// Implements this line of code: col_inds2 = xs2 % col
|
||||
Value colInds2 =
|
||||
rewriter.create<AtenRemainderScalarOp>(loc, xs2.getType(), xs2, col);
|
||||
|
||||
// Prepare tensors for concatenation
|
||||
Type listElemType =
|
||||
cast<Torch::BaseTensorType>(rowInds1.getType())
|
||||
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
|
||||
/*optionalDtype=*/nullptr);
|
||||
Type listType = Torch::ListType::get(listElemType);
|
||||
|
||||
Value sequenceRow = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, listType, SmallVector<Value>{rowInds1, rowInds2});
|
||||
Value sequenceCol = rewriter.create<Torch::PrimListConstructOp>(
|
||||
loc, listType, SmallVector<Value>{colInds1, colInds2});
|
||||
|
||||
// Concatenate row and col indices
|
||||
Type finalCatType = colInds1Type.getWithSizesAndDtype(
|
||||
{rectangleSizeInt + trapezoidSizeInt}, int64Type);
|
||||
|
||||
Value catRow = rewriter.create<AtenCatOp>(loc, finalCatType, sequenceRow,
|
||||
/*dim=*/cstZero);
|
||||
Value catCol = rewriter.create<AtenCatOp>(loc, finalCatType, sequenceCol,
|
||||
/*dim=*/cstZero);
|
||||
|
||||
// Make return value - stack row and col indices
|
||||
Value sequence = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(context, rowInds1.getType()),
|
||||
ValueRange{catRow, catCol});
|
||||
Type finalStackType = colInds1Type.getWithSizesAndDtype(
|
||||
ArrayRef<int64_t>{2, rectangleSizeInt + trapezoidSizeInt}, int64Type);
|
||||
|
||||
rewriter.replaceOpWithNewOp<AtenStackOp>(op, finalStackType, sequence,
|
||||
cstZero);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
|
||||
public:
|
||||
|
@ -9063,6 +9271,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuIndicesOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTrilIndicesOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_LinalgDetOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
|
|
|
@ -548,6 +548,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenReshapeAsOp>();
|
||||
target.addIllegalOp<AtenTriuOp>();
|
||||
target.addIllegalOp<AtenTriuIndicesOp>();
|
||||
target.addIllegalOp<AtenTrilIndicesOp>();
|
||||
target.addIllegalOp<AtenLinalgNormOp>();
|
||||
target.addIllegalOp<AtenFminOp>();
|
||||
target.addIllegalOp<AtenFmaxOp>();
|
||||
|
|
|
@ -1337,6 +1337,10 @@ STABLEHLO_PASS_SET = {
|
|||
"TriuIndicesModule_basic",
|
||||
"TriuIndicesAllZerosModule_basic",
|
||||
"TriuIndicesNegativeOffsetModule_basic",
|
||||
"TrilIndicesAllZerosModule_basic",
|
||||
"TrilIndicesModule_basic",
|
||||
"TrilIndicesNegativeOffsetModule_basic",
|
||||
"TrilIndicesOfssetGreaterThanRowModule_basic",
|
||||
"TupleModule_basic",
|
||||
"TypeAsDifferentModule_basic",
|
||||
"TypeAsSameModule_basic",
|
||||
|
|
|
@ -1882,6 +1882,29 @@ def aten〇triu_indices〡shape(row: int, col: int, offset: int = 0, dtype: Opti
|
|||
|
||||
return [2, triu_size]
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(4, 3, 1), # Basic case.
|
||||
Invocation(0, 0, 0), # All zeros case.
|
||||
Invocation(7, 5, -2), # Negative offset case.
|
||||
Invocation(35, 55, 16), # Largere values case.
|
||||
])
|
||||
def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||
if row == 0 or col == 0:
|
||||
return [2, 0]
|
||||
|
||||
m_first_row = min(col, 1 + offset) if offset > 0 else int(row + offset > 0)
|
||||
m_last_row = max(0, min(col, row + offset))
|
||||
n_row_all = max(0, min(row, row + offset))
|
||||
n_row_trapezoid = m_last_row - m_first_row + 1
|
||||
|
||||
# Number of elements in top trapezoid
|
||||
trapezoid_size = (m_first_row + m_last_row) * n_row_trapezoid // 2
|
||||
# Number of elements in bottom rectangle
|
||||
diff_row = n_row_all - n_row_trapezoid
|
||||
rectangle_size = max(0, diff_row * col)
|
||||
|
||||
return [2, trapezoid_size + rectangle_size]
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
|
||||
Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
|
||||
|
@ -5254,6 +5277,9 @@ def aten〇dequantize〇tensor〡dtype(qtensor_rank_dtype: Tuple[int, int]) -> i
|
|||
def aten〇triu_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
return torch.int64 if dtype is None else dtype
|
||||
|
||||
def aten〇tril_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int:
|
||||
return torch.int64 if dtype is None else dtype
|
||||
|
||||
def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
if (self_dtype == torch.quint8):
|
||||
|
|
|
@ -1090,6 +1090,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
has_verifier=True,
|
||||
)
|
||||
|
||||
emit(
|
||||
"aten::tril_indices : (int, int, int, int?, int?, Device?, bool?) -> (Tensor)",
|
||||
has_verifier=True,
|
||||
)
|
||||
|
||||
# backprop ops
|
||||
emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)")
|
||||
emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)")
|
||||
|
|
|
@ -6364,3 +6364,82 @@ class TriuIndicesNegativeOffsetModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: TriuIndicesNegativeOffsetModule())
|
||||
def TriuIndicesNegativeOffsetModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class TrilIndicesModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.ops.aten.tril_indices(4, 3, 1)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TrilIndicesModule())
|
||||
def TrilIndicesModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class TrilIndicesAllZerosModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.ops.aten.tril_indices(0, 0, 0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TrilIndicesAllZerosModule())
|
||||
def TrilIndicesAllZerosModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class TrilIndicesNegativeOffsetModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.ops.aten.tril_indices(5, 16, -2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TrilIndicesNegativeOffsetModule())
|
||||
def TrilIndicesNegativeOffsetModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
||||
|
||||
class TrilIndicesOfssetGreaterThanRowModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
]
|
||||
)
|
||||
def forward(self):
|
||||
return torch.ops.aten.tril_indices(7, 9, 8)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule())
|
||||
def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils):
|
||||
module.forward()
|
||||
|
|
Loading…
Reference in New Issue