[TorchToLinalg][ONNX] Add Basic Determinant Support (#3481)

This adds support for a few ops:

- torch.linalg_det
- torch._linalg_det (if the LU and pivot returns are unused)
- onnx.Det

An scf loop is used, since the row reduction algorithm applied here has
some loop-carried dependencies.
The current support being added here is very basic, and only works if no
permutations are required during row reduction, and assumes the matrices
are non-singular.
pull/3504/head
zjgarvey 2024-06-25 13:34:19 -05:00 committed by GitHub
parent 368fabf0c1
commit d2bc70f188
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 459 additions and 1 deletions

View File

@ -8586,6 +8586,54 @@ def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [
}];
}
def Torch_AtenLinalgDetOp : Torch_Op<"aten.linalg_det", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::linalg_det : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$A
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLinalgDetOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLinalgDetOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_Aten_LinalgDetOp : Torch_Op<"aten._linalg_det", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$A
);
let results = (outs
AnyTorchOptionalTensorType:$result,
AnyTorchOptionalTensorType:$LU,
AnyTorchOptionalTensorType:$pivots
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_LinalgDetOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 3);
}
void Aten_LinalgDetOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 3);
}
}];
}
def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -1972,6 +1972,16 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
useMaskValue);
return success();
});
patterns.onOp(
"Det", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
if (binder.tensorOperand(input) || binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenLinalgDetOp>(binder.op,
resultType, input);
return success();
});
patterns.onOp(
"DequantizeLinear", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -16,6 +16,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
@ -42,6 +43,7 @@ public:
registry.insert<tensor::TensorDialect>();
registry.insert<arith::ArithDialect>();
registry.insert<cf::ControlFlowDialect>();
registry.insert<scf::SCFDialect>();
registry.insert<complex::ComplexDialect>();
TorchConversion::getBackendTypeConversionDependentDialects(registry);
}
@ -51,7 +53,7 @@ public:
ConversionTarget target(*context);
target.addLegalDialect<
linalg::LinalgDialect, func::FuncDialect, cf::ControlFlowDialect,
math::MathDialect, sparse_tensor::SparseTensorDialect,
math::MathDialect, scf::SCFDialect, sparse_tensor::SparseTensorDialect,
tensor::TensorDialect, arith::ArithDialect, complex::ComplexDialect>();
target.addLegalOp<TorchConversion::GetNextSeedOp>();

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
@ -2952,6 +2953,218 @@ public:
}
};
} // namespace
namespace {
// This pattern row reduces a matrix, then returns the product of it's diagonal
// elements
class ConvertAtenLinalgDetOp : public OpConversionPattern<AtenLinalgDetOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenLinalgDetOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
Value input = adaptor.getA();
auto inputType = cast<RankedTensorType>(input.getType());
unsigned inputRank = inputType.getRank();
auto elemTy = inputType.getElementType();
bool isBatched = (inputRank == 3);
Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value cstOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value cstZeroF = getConstant(rewriter, loc, 0, elemTy);
// get some shapes
SmallVector<int64_t> inputShape(inputType.getShape());
SmallVector<int64_t> sliceShape(inputShape);
sliceShape.pop_back();
SmallVector<int64_t> diagShape({isBatched ? inputType.getShape()[0] : 1});
auto sliceTy = RankedTensorType::get(sliceShape, elemTy);
auto diagTy = RankedTensorType::get(diagShape, elemTy);
// get some sizes
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
Value chDim = isBatched ? inputSizes[0] : cstOne;
Value matDim = inputSizes[inputRank - 1];
Value matDimMinusOne = rewriter.create<arith::SubIOp>(loc, matDim, cstOne);
ArrayRef<Value> sliceSizes(inputSizes.begin(), inputSizes.end() - 1);
// initialize a tensor to store the diagonal elements found during row
// reduction
Value initDiags = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(sliceSizes), elemTy);
// loop over each pivot row in A. Get the diagonal, then reduce the
// subdiagonal Don't perform the loop on the last row since no further
// reduction is needed.
auto rowReductionLoop = rewriter.create<scf::ForOp>(
loc, /*start=*/cstZero, /*end=*/matDimMinusOne, /*step=*/cstOne,
/*yeild_to=*/ValueRange{input, initDiags}, /*body_lambda=*/
[&](OpBuilder &b, Location loc, Value row, ValueRange vals) {
// extract row i from input Tensor of shape CxNxN or shape
// NxN.
OpFoldResult cstOneFold = getAsOpFoldResult(cstOne);
OpFoldResult cstZeroFold = getAsOpFoldResult(cstZero);
SmallVector<OpFoldResult> offsets(inputRank, cstZeroFold);
offsets[inputRank - 2] = row;
SmallVector<OpFoldResult> strides(inputRank, cstOneFold);
auto sizes = getAsOpFoldResult(inputSizes);
sizes[inputRank - 2] = cstOneFold;
// offsets = [0, row, 0], sizes = [C, 1, N] -> pivot row
Value pivot = b.create<tensor::ExtractSliceOp>(
loc, sliceTy, vals[0], offsets, sizes, strides);
// extract diagonal elements and insert them into vals[1]
offsets.back() = row;
sizes.back() = cstOneFold;
// offsets = [0, row, row], sizes = [C, 1, 1] -> diag(row,row)
Value diag = b.create<tensor::ExtractSliceOp>(
loc, diagTy, vals[0], offsets, sizes, strides);
SmallVector<OpFoldResult> diagOffsets(inputRank - 1, cstZeroFold);
diagOffsets.back() = row;
SmallVector<OpFoldResult> diagStrides(inputRank - 1, cstOneFold);
SmallVector<OpFoldResult> diagSizes = getAsOpFoldResult(sliceSizes);
diagSizes.back() = cstOneFold;
// offsets = [0, row], sizes = [C, 1] insert to [C,N]
Value updatedDiags = b.create<tensor::InsertSliceOp>(
loc, diag, vals[1], diagOffsets, diagSizes, diagStrides);
// the subpivot matrix column size, as a Value, is matDim - row -
// cstOne. This can't be statically converted to an int64_t, since row
// is the loop index, so this is left as a dynamic dim.
SmallVector<int64_t> subPivotShape(inputType.getShape());
subPivotShape[inputRank - 2] = ShapedType::kDynamic;
ArrayRef<int64_t> subDiagShape(subPivotShape.begin(),
subPivotShape.end() - 1);
auto subPivotTy = RankedTensorType::get(subPivotShape, elemTy);
auto subDiagTy = RankedTensorType::get(subDiagShape, elemTy);
Value rowPlusOne = b.create<arith::AddIOp>(loc, row, cstOne);
offsets[inputRank - 2] = getAsOpFoldResult(rowPlusOne);
sizes[inputRank - 2] = getAsOpFoldResult(
b.create<arith::SubIOp>(loc, matDim, rowPlusOne));
// offsets = [0, row + 1, row], sizes = [C, N - row - 1, 1] -> A_j,row
// with j > row
Value subDiag = b.create<tensor::ExtractSliceOp>(
loc, subDiagTy, vals[0], offsets, sizes, strides);
offsets.back() = cstZeroFold;
sizes.back() = getAsOpFoldResult(matDim);
// offsets = [0, row + 1, 0], sizes = [C, N - row - 1, N] -> elements
// below pivot row
Value subPivot = b.create<tensor::ExtractSliceOp>(
loc, subPivotTy, vals[0], offsets, sizes, strides);
Value initResult = b.create<tensor::EmptyOp>(loc, sizes, elemTy);
// write a generic op to perform subpivot = subpivot -
// (subdiag/diag)*pivot
// d0 = batches, d1 = row, d2 = column -> pivot(d0,d2), diag(d0),
// subPivot(d0,d1,d2), subDiag(d0, d1); output(d0,d1,d2)
SmallVector<AffineExpr> allDims;
for (unsigned i = 0; i < inputRank; i++)
allDims.push_back(b.getAffineDimExpr(i));
SmallVector<AffineExpr> rowIterator(1, allDims[0]);
SmallVector<AffineExpr> colIterator;
SmallVector<AffineExpr> batchIterator;
if (isBatched) {
rowIterator.push_back(allDims[1]);
colIterator.push_back(allDims[0]);
colIterator.push_back(allDims[2]);
batchIterator.push_back(allDims[0]);
} else {
colIterator.push_back(allDims[1]);
batchIterator.push_back(getAffineConstantExpr(0, context));
}
SmallVector<AffineMap> indexingMaps;
indexingMaps.push_back(
AffineMap::get(inputRank, 0, colIterator, context));
indexingMaps.push_back(
AffineMap::get(inputRank, 0, batchIterator, context));
indexingMaps.push_back(b.getMultiDimIdentityMap(inputRank));
indexingMaps.push_back(
AffineMap::get(inputRank, 0, rowIterator, context));
indexingMaps.push_back(b.getMultiDimIdentityMap(inputRank));
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
Value reducedSubPivot =
b.create<linalg::GenericOp>(
loc, subPivotTy, ValueRange{pivot, diag, subPivot, subDiag},
initResult, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
// for d0 in batches, d1 in subpivotrows, d2 in columns
// let i represent the pivot row index (scf loop index)
Value pivotd0d2 = args[0];
Value diagd0 = args[1];
Value subPivotd0d1d2 = args[2];
Value subDiagd0d1 = args[3];
// coeff = A_d1,i / A_i,i
Value coeff =
b.create<arith::DivFOp>(loc, subDiagd0d1, diagd0);
auto cmp = b.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::ONE, diagd0, cstZeroF);
b.create<cf::AssertOp>(
loc, cmp,
b.getStringAttr(
"unimplemented: determinants requiring "
"permutations and singular matrices"));
// coeff*A_i,d2
Value scaledPivotValue =
b.create<arith::MulFOp>(loc, coeff, pivotd0d2);
// result = A_d1,d2 - (A_d1,i/A_i,i)*A_i,d2
// so that when d2 = i, A_d1,i - (A_d1,i/A_i,i) * A_i,i = 0
Value result = b.create<arith::SubFOp>(loc, subPivotd0d1d2,
scaledPivotValue);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);
Value rowReductionResult = b.create<tensor::InsertSliceOp>(
loc, reducedSubPivot, vals[0], offsets, sizes, strides);
b.create<scf::YieldOp>(loc,
ValueRange{rowReductionResult, updatedDiags});
});
Value allDiagsExceptLast = rowReductionLoop.getResult(1);
SmallVector<OpFoldResult> offsets(inputRank,
getAsOpFoldResult(matDimMinusOne));
SmallVector<OpFoldResult> strides(inputRank, getAsOpFoldResult(cstOne));
SmallVector<OpFoldResult> sizes(inputRank, getAsOpFoldResult(cstOne));
sizes[0] = getAsOpFoldResult(chDim);
if (isBatched)
offsets[0] = getAsOpFoldResult(cstZero);
Value lastDiag = rewriter.create<tensor::ExtractSliceOp>(
loc, diagTy, rowReductionLoop.getResult(0), offsets, sizes, strides);
offsets.pop_back();
strides.pop_back();
sizes.pop_back();
Value allDiags = rewriter.create<tensor::InsertSliceOp>(
loc, lastDiag, allDiagsExceptLast, offsets, sizes, strides);
// linalg generic to do reduce prod for allDiags along back dim.
// the result of that generic will be the determinant
SmallVector<AffineMap> indexingMaps;
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(inputRank - 1));
AffineExpr resultExpr = isBatched ? rewriter.getAffineDimExpr(0)
: getAffineConstantExpr(0, context);
indexingMaps.push_back(AffineMap::get(inputRank - 1, 0, resultExpr));
SmallVector<utils::IteratorType> iteratorTypes(
inputRank - 1, utils::IteratorType::parallel);
Value initDet = createInitTensor(rewriter, loc, ValueRange{chDim}, elemTy,
getConstant(rewriter, loc, 1.0, elemTy));
Value determinant =
rewriter
.create<linalg::GenericOp>(
loc, initDet.getType(), ValueRange{allDiags}, initDet,
indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value prod = b.create<arith::MulFOp>(loc, args[0], args[1]);
b.create<linalg::YieldOp>(loc, prod);
})
.getResult(0);
Type newResultType =
getTypeConverter()->convertType(op.getResult().getType());
if (isBatched) {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
determinant);
return success();
}
Value detVal = rewriter.create<tensor::ExtractOp>(
loc, determinant, SmallVector<Value>(1, cstZero));
rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, newResultType,
ValueRange{detVal});
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@ -3009,4 +3222,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenGridSamplerOp>(typeConverter, context);
target.addIllegalOp<Aten__InterpolateSizeListScaleListOp>();
patterns.add<ConvertInterpolateOp>(typeConverter, context);
target.addIllegalOp<AtenLinalgDetOp>();
patterns.add<ConvertAtenLinalgDetOp>(typeConverter, context);
}

View File

@ -6485,6 +6485,68 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linalg_det\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %int-2 = torch.constant.int -2\n"
" %int-1 = torch.constant.int -1\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %true = torch.constant.bool true\n"
" %int2 = torch.constant.int 2\n"
" %int3 = torch.constant.int 3\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
" torch.prim.If.yield %true : !torch.bool\n"
" } else {\n"
" %9 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %10 = torch.aten.eq.int %9, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %10 : !torch.bool\n"
" }\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %4 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %7 = torch.aten.eq.int %6, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.list<int>) {\n"
" %9 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield %9 : !torch.list<int>\n"
" } else {\n"
" %9 = torch.derefine %arg0 : !torch.list<int> to !torch.any\n"
" %10 = func.call @__torch__.torch.jit._shape_functions.zero_dim_tensor(%9) : (!torch.any) -> !torch.list<int>\n"
" torch.prim.If.yield %10 : !torch.list<int>\n"
" }\n"
" return %8 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._linalg_det\"(%arg0: !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
" %none = torch.constant.none\n"
" %int1 = torch.constant.int 1\n"
" %int-1 = torch.constant.int -1\n"
" %0 = call @\"__torch_mlir_shape_fn.aten.linalg_det\"(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %1 = torch.aten.slice.t %arg0, %none, %int-1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %0, %arg0, %1 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._linalg_det\"(%arg0: !torch.tuple<int, int>) -> !torch.tuple<int, int, int> {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" %1 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" %2 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
" return %3 : !torch.tuple<int, int, int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -10986,6 +11048,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_det\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.dropout\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -2619,6 +2619,28 @@ public:
};
} // namespace
namespace {
class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_LinalgDetOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value, 3> results = op.getResults();
if (!results[1].use_empty() || !results[2].use_empty())
return rewriter.notifyMatchFailure(
op, "unsupported: _linalg_det results: LU and pivot");
Location loc = op.getLoc();
Value input = op.getA();
Value determinant = rewriter.create<Torch::AtenLinalgDetOp>(
loc, results[0].getType(), input);
rewriter.replaceAllUsesWith(results[0], determinant);
rewriter.eraseOp(op);
return success();
}
};
} // namespace
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
// prims.collapse operations.
//
@ -8701,6 +8723,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuIndicesOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LinalgDetOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns);
// More specific conv ops

View File

@ -404,6 +404,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenMvOp>();
target.addIllegalOp<AtenRenormOp>();
target.addIllegalOp<AtenLinalgCrossOp>();
target.addIllegalOp<Aten_LinalgDetOp>();
target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenTOp>();
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();

View File

@ -559,6 +559,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ConvolutionBackwardModule2D_basic",
"CumsumModule_basic",
"DeformConv2D_basic",
"DeterminantBatchedModule_F32",
"DeterminantDynamicModule_F32",
"DeterminantModule_F32",
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",
@ -2939,6 +2942,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"DeformConv2D_basic",
"DeterminantBatchedModule_F32",
"DeterminantDynamicModule_F32",
"DeterminantModule_F32",
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",
@ -3734,6 +3740,10 @@ ONNX_TOSA_XFAIL_SET = {
"CumsumStaticModule_basic",
"CumsumStaticNegativeDimModule_basic",
"DeformConv2D_basic",
"DeterminantModule_F32",
"DeterminantBatchedModule_F32",
"DeterminantDynamicModule_F32",
"DeterminantModule_F32",
"DiagonalModule_basic",
"DiagonalModule_nonsquare",
"DiagonalModule_transposed",

View File

@ -223,6 +223,19 @@ def atensign〡shape(self: List[int]) -> List[int]:
def atensgn〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenlinalg_det〡shape(A: List[int]) -> List[int]:
assert len(A) == 2 or len(A) == 3
assert A[-1] == A[-2]
if len(A) == 3:
return A[:1]
return upstream_shape_functions.zero_dim_tensor(A)
def aten_linalg_det〡shape(A: List[int]) -> Tuple[List[int], List[int], List[int]]:
return (atenlinalg_det〡shape(A), A, A[:-1])
def aten_linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int, int]:
return (A_rank_dtype[1], A_rank_dtype[1], A_rank_dtype[1])
def atendetach〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
@ -2630,6 +2643,12 @@ def atendetach〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,4),], error_types={*all_integer_dtypes()}))
def atenlinalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = A_rank_dtype
assert not is_integer_dtype(self_dtype)
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False))
def atendropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: bool) -> int:
input_rank, input_dtype = input_rank_dtype

View File

@ -699,6 +699,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)")
emit("aten::linalg_det : (Tensor) -> (Tensor)")
emit("aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)")
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")

View File

@ -43,6 +43,7 @@ def register_all_tests():
from . import slice_like
from . import nll_loss
from . import index_select
from . import linalg_algorithms
from . import arange
from . import constant_alloc
from . import threshold

View File

@ -0,0 +1,51 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export
# ==============================================================================
class DeterminantModule(torch.nn.Module):
@export
@annotate_args([None, [(4, 4), torch.float32, True]])
def forward(self, A):
return torch.linalg.det(A)
@register_test_case(module_factory=lambda: DeterminantModule())
def DeterminantModule_F32(module, tu: TestUtils):
A = tu.rand(4, 4).to(dtype=torch.float32)
module.forward(A)
class DeterminantBatchedModule(torch.nn.Module):
@export
@annotate_args([None, [(3, 4, 4), torch.float32, True]])
def forward(self, A):
return torch.linalg.det(A)
@register_test_case(module_factory=lambda: DeterminantBatchedModule())
def DeterminantBatchedModule_F32(module, tu: TestUtils):
A = tu.rand(3, 4, 4).to(dtype=torch.float32)
module.forward(A)
class DeterminantDynamicModule(torch.nn.Module):
@export
@annotate_args([None, [(-1, -1, -1), torch.float32, True]])
def forward(self, A):
return torch.linalg.det(A)
@register_test_case(module_factory=lambda: DeterminantBatchedModule())
def DeterminantDynamicModule_F32(module, tu: TestUtils):
A = tu.rand(3, 4, 4).to(dtype=torch.float32)
module.forward(A)