mirror of https://github.com/llvm/torch-mlir
[TorchToLinalg][ONNX] Add Basic Determinant Support (#3481)
This adds support for a few ops: - torch.linalg_det - torch._linalg_det (if the LU and pivot returns are unused) - onnx.Det An scf loop is used, since the row reduction algorithm applied here has some loop-carried dependencies. The current support being added here is very basic, and only works if no permutations are required during row reduction, and assumes the matrices are non-singular.pull/3504/head
parent
368fabf0c1
commit
d2bc70f188
|
@ -8586,6 +8586,54 @@ def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenLinalgDetOp : Torch_Op<"aten.linalg_det", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::linalg_det : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$A
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenLinalgDetOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenLinalgDetOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten_LinalgDetOp : Torch_Op<"aten._linalg_det", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$A
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result,
|
||||
AnyTorchOptionalTensorType:$LU,
|
||||
AnyTorchOptionalTensorType:$pivots
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult Aten_LinalgDetOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 3);
|
||||
}
|
||||
void Aten_LinalgDetOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 3);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFrobeniusNormDimOp : Torch_Op<"aten.frobenius_norm.dim", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -1972,6 +1972,16 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
|||
useMaskValue);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"Det", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
Value input;
|
||||
if (binder.tensorOperand(input) || binder.tensorResultType(resultType))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenLinalgDetOp>(binder.op,
|
||||
resultType, input);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
|
||||
"DequantizeLinear", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
|
||||
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
|
||||
|
@ -42,6 +43,7 @@ public:
|
|||
registry.insert<tensor::TensorDialect>();
|
||||
registry.insert<arith::ArithDialect>();
|
||||
registry.insert<cf::ControlFlowDialect>();
|
||||
registry.insert<scf::SCFDialect>();
|
||||
registry.insert<complex::ComplexDialect>();
|
||||
TorchConversion::getBackendTypeConversionDependentDialects(registry);
|
||||
}
|
||||
|
@ -51,7 +53,7 @@ public:
|
|||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<
|
||||
linalg::LinalgDialect, func::FuncDialect, cf::ControlFlowDialect,
|
||||
math::MathDialect, sparse_tensor::SparseTensorDialect,
|
||||
math::MathDialect, scf::SCFDialect, sparse_tensor::SparseTensorDialect,
|
||||
tensor::TensorDialect, arith::ArithDialect, complex::ComplexDialect>();
|
||||
target.addLegalOp<TorchConversion::GetNextSeedOp>();
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Math/IR/Math.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
|
||||
#include "torch-mlir/Conversion/Utils/Utils.h"
|
||||
|
@ -2952,6 +2953,218 @@ public:
|
|||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// This pattern row reduces a matrix, then returns the product of it's diagonal
|
||||
// elements
|
||||
class ConvertAtenLinalgDetOp : public OpConversionPattern<AtenLinalgDetOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenLinalgDetOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
Value input = adaptor.getA();
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
unsigned inputRank = inputType.getRank();
|
||||
auto elemTy = inputType.getElementType();
|
||||
bool isBatched = (inputRank == 3);
|
||||
Value cstZero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||
Value cstOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||
Value cstZeroF = getConstant(rewriter, loc, 0, elemTy);
|
||||
// get some shapes
|
||||
SmallVector<int64_t> inputShape(inputType.getShape());
|
||||
SmallVector<int64_t> sliceShape(inputShape);
|
||||
sliceShape.pop_back();
|
||||
SmallVector<int64_t> diagShape({isBatched ? inputType.getShape()[0] : 1});
|
||||
auto sliceTy = RankedTensorType::get(sliceShape, elemTy);
|
||||
auto diagTy = RankedTensorType::get(diagShape, elemTy);
|
||||
// get some sizes
|
||||
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||
Value chDim = isBatched ? inputSizes[0] : cstOne;
|
||||
Value matDim = inputSizes[inputRank - 1];
|
||||
Value matDimMinusOne = rewriter.create<arith::SubIOp>(loc, matDim, cstOne);
|
||||
ArrayRef<Value> sliceSizes(inputSizes.begin(), inputSizes.end() - 1);
|
||||
// initialize a tensor to store the diagonal elements found during row
|
||||
// reduction
|
||||
Value initDiags = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(sliceSizes), elemTy);
|
||||
// loop over each pivot row in A. Get the diagonal, then reduce the
|
||||
// subdiagonal Don't perform the loop on the last row since no further
|
||||
// reduction is needed.
|
||||
auto rowReductionLoop = rewriter.create<scf::ForOp>(
|
||||
loc, /*start=*/cstZero, /*end=*/matDimMinusOne, /*step=*/cstOne,
|
||||
/*yeild_to=*/ValueRange{input, initDiags}, /*body_lambda=*/
|
||||
[&](OpBuilder &b, Location loc, Value row, ValueRange vals) {
|
||||
// extract row i from input Tensor of shape CxNxN or shape
|
||||
// NxN.
|
||||
OpFoldResult cstOneFold = getAsOpFoldResult(cstOne);
|
||||
OpFoldResult cstZeroFold = getAsOpFoldResult(cstZero);
|
||||
SmallVector<OpFoldResult> offsets(inputRank, cstZeroFold);
|
||||
offsets[inputRank - 2] = row;
|
||||
SmallVector<OpFoldResult> strides(inputRank, cstOneFold);
|
||||
auto sizes = getAsOpFoldResult(inputSizes);
|
||||
sizes[inputRank - 2] = cstOneFold;
|
||||
// offsets = [0, row, 0], sizes = [C, 1, N] -> pivot row
|
||||
Value pivot = b.create<tensor::ExtractSliceOp>(
|
||||
loc, sliceTy, vals[0], offsets, sizes, strides);
|
||||
// extract diagonal elements and insert them into vals[1]
|
||||
offsets.back() = row;
|
||||
sizes.back() = cstOneFold;
|
||||
// offsets = [0, row, row], sizes = [C, 1, 1] -> diag(row,row)
|
||||
Value diag = b.create<tensor::ExtractSliceOp>(
|
||||
loc, diagTy, vals[0], offsets, sizes, strides);
|
||||
SmallVector<OpFoldResult> diagOffsets(inputRank - 1, cstZeroFold);
|
||||
diagOffsets.back() = row;
|
||||
SmallVector<OpFoldResult> diagStrides(inputRank - 1, cstOneFold);
|
||||
SmallVector<OpFoldResult> diagSizes = getAsOpFoldResult(sliceSizes);
|
||||
diagSizes.back() = cstOneFold;
|
||||
// offsets = [0, row], sizes = [C, 1] insert to [C,N]
|
||||
Value updatedDiags = b.create<tensor::InsertSliceOp>(
|
||||
loc, diag, vals[1], diagOffsets, diagSizes, diagStrides);
|
||||
// the subpivot matrix column size, as a Value, is matDim - row -
|
||||
// cstOne. This can't be statically converted to an int64_t, since row
|
||||
// is the loop index, so this is left as a dynamic dim.
|
||||
SmallVector<int64_t> subPivotShape(inputType.getShape());
|
||||
subPivotShape[inputRank - 2] = ShapedType::kDynamic;
|
||||
ArrayRef<int64_t> subDiagShape(subPivotShape.begin(),
|
||||
subPivotShape.end() - 1);
|
||||
auto subPivotTy = RankedTensorType::get(subPivotShape, elemTy);
|
||||
auto subDiagTy = RankedTensorType::get(subDiagShape, elemTy);
|
||||
Value rowPlusOne = b.create<arith::AddIOp>(loc, row, cstOne);
|
||||
offsets[inputRank - 2] = getAsOpFoldResult(rowPlusOne);
|
||||
sizes[inputRank - 2] = getAsOpFoldResult(
|
||||
b.create<arith::SubIOp>(loc, matDim, rowPlusOne));
|
||||
// offsets = [0, row + 1, row], sizes = [C, N - row - 1, 1] -> A_j,row
|
||||
// with j > row
|
||||
Value subDiag = b.create<tensor::ExtractSliceOp>(
|
||||
loc, subDiagTy, vals[0], offsets, sizes, strides);
|
||||
offsets.back() = cstZeroFold;
|
||||
sizes.back() = getAsOpFoldResult(matDim);
|
||||
// offsets = [0, row + 1, 0], sizes = [C, N - row - 1, N] -> elements
|
||||
// below pivot row
|
||||
Value subPivot = b.create<tensor::ExtractSliceOp>(
|
||||
loc, subPivotTy, vals[0], offsets, sizes, strides);
|
||||
Value initResult = b.create<tensor::EmptyOp>(loc, sizes, elemTy);
|
||||
// write a generic op to perform subpivot = subpivot -
|
||||
// (subdiag/diag)*pivot
|
||||
// d0 = batches, d1 = row, d2 = column -> pivot(d0,d2), diag(d0),
|
||||
// subPivot(d0,d1,d2), subDiag(d0, d1); output(d0,d1,d2)
|
||||
SmallVector<AffineExpr> allDims;
|
||||
for (unsigned i = 0; i < inputRank; i++)
|
||||
allDims.push_back(b.getAffineDimExpr(i));
|
||||
SmallVector<AffineExpr> rowIterator(1, allDims[0]);
|
||||
SmallVector<AffineExpr> colIterator;
|
||||
SmallVector<AffineExpr> batchIterator;
|
||||
if (isBatched) {
|
||||
rowIterator.push_back(allDims[1]);
|
||||
colIterator.push_back(allDims[0]);
|
||||
colIterator.push_back(allDims[2]);
|
||||
batchIterator.push_back(allDims[0]);
|
||||
} else {
|
||||
colIterator.push_back(allDims[1]);
|
||||
batchIterator.push_back(getAffineConstantExpr(0, context));
|
||||
}
|
||||
SmallVector<AffineMap> indexingMaps;
|
||||
indexingMaps.push_back(
|
||||
AffineMap::get(inputRank, 0, colIterator, context));
|
||||
indexingMaps.push_back(
|
||||
AffineMap::get(inputRank, 0, batchIterator, context));
|
||||
indexingMaps.push_back(b.getMultiDimIdentityMap(inputRank));
|
||||
indexingMaps.push_back(
|
||||
AffineMap::get(inputRank, 0, rowIterator, context));
|
||||
indexingMaps.push_back(b.getMultiDimIdentityMap(inputRank));
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
inputRank, utils::IteratorType::parallel);
|
||||
Value reducedSubPivot =
|
||||
b.create<linalg::GenericOp>(
|
||||
loc, subPivotTy, ValueRange{pivot, diag, subPivot, subDiag},
|
||||
initResult, indexingMaps, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
// for d0 in batches, d1 in subpivotrows, d2 in columns
|
||||
// let i represent the pivot row index (scf loop index)
|
||||
Value pivotd0d2 = args[0];
|
||||
Value diagd0 = args[1];
|
||||
Value subPivotd0d1d2 = args[2];
|
||||
Value subDiagd0d1 = args[3];
|
||||
// coeff = A_d1,i / A_i,i
|
||||
Value coeff =
|
||||
b.create<arith::DivFOp>(loc, subDiagd0d1, diagd0);
|
||||
auto cmp = b.create<arith::CmpFOp>(
|
||||
loc, arith::CmpFPredicate::ONE, diagd0, cstZeroF);
|
||||
b.create<cf::AssertOp>(
|
||||
loc, cmp,
|
||||
b.getStringAttr(
|
||||
"unimplemented: determinants requiring "
|
||||
"permutations and singular matrices"));
|
||||
// coeff*A_i,d2
|
||||
Value scaledPivotValue =
|
||||
b.create<arith::MulFOp>(loc, coeff, pivotd0d2);
|
||||
// result = A_d1,d2 - (A_d1,i/A_i,i)*A_i,d2
|
||||
// so that when d2 = i, A_d1,i - (A_d1,i/A_i,i) * A_i,i = 0
|
||||
Value result = b.create<arith::SubFOp>(loc, subPivotd0d1d2,
|
||||
scaledPivotValue);
|
||||
b.create<linalg::YieldOp>(loc, result);
|
||||
})
|
||||
.getResult(0);
|
||||
Value rowReductionResult = b.create<tensor::InsertSliceOp>(
|
||||
loc, reducedSubPivot, vals[0], offsets, sizes, strides);
|
||||
b.create<scf::YieldOp>(loc,
|
||||
ValueRange{rowReductionResult, updatedDiags});
|
||||
});
|
||||
Value allDiagsExceptLast = rowReductionLoop.getResult(1);
|
||||
SmallVector<OpFoldResult> offsets(inputRank,
|
||||
getAsOpFoldResult(matDimMinusOne));
|
||||
SmallVector<OpFoldResult> strides(inputRank, getAsOpFoldResult(cstOne));
|
||||
SmallVector<OpFoldResult> sizes(inputRank, getAsOpFoldResult(cstOne));
|
||||
sizes[0] = getAsOpFoldResult(chDim);
|
||||
if (isBatched)
|
||||
offsets[0] = getAsOpFoldResult(cstZero);
|
||||
Value lastDiag = rewriter.create<tensor::ExtractSliceOp>(
|
||||
loc, diagTy, rowReductionLoop.getResult(0), offsets, sizes, strides);
|
||||
offsets.pop_back();
|
||||
strides.pop_back();
|
||||
sizes.pop_back();
|
||||
Value allDiags = rewriter.create<tensor::InsertSliceOp>(
|
||||
loc, lastDiag, allDiagsExceptLast, offsets, sizes, strides);
|
||||
// linalg generic to do reduce prod for allDiags along back dim.
|
||||
// the result of that generic will be the determinant
|
||||
SmallVector<AffineMap> indexingMaps;
|
||||
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(inputRank - 1));
|
||||
AffineExpr resultExpr = isBatched ? rewriter.getAffineDimExpr(0)
|
||||
: getAffineConstantExpr(0, context);
|
||||
indexingMaps.push_back(AffineMap::get(inputRank - 1, 0, resultExpr));
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
inputRank - 1, utils::IteratorType::parallel);
|
||||
Value initDet = createInitTensor(rewriter, loc, ValueRange{chDim}, elemTy,
|
||||
getConstant(rewriter, loc, 1.0, elemTy));
|
||||
Value determinant =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initDet.getType(), ValueRange{allDiags}, initDet,
|
||||
indexingMaps, iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value prod = b.create<arith::MulFOp>(loc, args[0], args[1]);
|
||||
b.create<linalg::YieldOp>(loc, prod);
|
||||
})
|
||||
.getResult(0);
|
||||
Type newResultType =
|
||||
getTypeConverter()->convertType(op.getResult().getType());
|
||||
if (isBatched) {
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
|
||||
determinant);
|
||||
return success();
|
||||
}
|
||||
Value detVal = rewriter.create<tensor::ExtractOp>(
|
||||
loc, determinant, SmallVector<Value>(1, cstZero));
|
||||
rewriter.replaceOpWithNewOp<tensor::FromElementsOp>(op, newResultType,
|
||||
ValueRange{detVal});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
|
@ -3009,4 +3222,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
patterns.add<ConvertAtenGridSamplerOp>(typeConverter, context);
|
||||
target.addIllegalOp<Aten__InterpolateSizeListScaleListOp>();
|
||||
patterns.add<ConvertInterpolateOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenLinalgDetOp>();
|
||||
patterns.add<ConvertAtenLinalgDetOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -6485,6 +6485,68 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.linalg_det\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %int-2 = torch.constant.int -2\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %true = torch.constant.bool true\n"
|
||||
" %int2 = torch.constant.int 2\n"
|
||||
" %int3 = torch.constant.int 3\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||
" torch.prim.If.yield %true : !torch.bool\n"
|
||||
" } else {\n"
|
||||
" %9 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %10 = torch.aten.eq.int %9, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If.yield %10 : !torch.bool\n"
|
||||
" }\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %4 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||
" %5 = torch.aten.eq.int %3, %4 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" torch.prim.If %5 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" %6 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
|
||||
" %7 = torch.aten.eq.int %6, %int3 : !torch.int, !torch.int -> !torch.bool\n"
|
||||
" %8 = torch.prim.If %7 -> (!torch.list<int>) {\n"
|
||||
" %9 = torch.aten.slice.t %arg0, %none, %int1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %9 : !torch.list<int>\n"
|
||||
" } else {\n"
|
||||
" %9 = torch.derefine %arg0 : !torch.list<int> to !torch.any\n"
|
||||
" %10 = func.call @__torch__.torch.jit._shape_functions.zero_dim_tensor(%9) : (!torch.any) -> !torch.list<int>\n"
|
||||
" torch.prim.If.yield %10 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" return %8 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._linalg_det\"(%arg0: !torch.list<int>) -> !torch.tuple<list<int>, list<int>, list<int>> {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %int-1 = torch.constant.int -1\n"
|
||||
" %0 = call @\"__torch_mlir_shape_fn.aten.linalg_det\"(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" %1 = torch.aten.slice.t %arg0, %none, %int-1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
|
||||
" %2 = torch.prim.TupleConstruct %0, %arg0, %1 : !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" return %2 : !torch.tuple<list<int>, list<int>, list<int>>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._linalg_det\"(%arg0: !torch.tuple<int, int>) -> !torch.tuple<int, int, int> {\n"
|
||||
" %int1 = torch.constant.int 1\n"
|
||||
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
|
||||
" %1 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
|
||||
" %2 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
|
||||
" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.int, !torch.int, !torch.int -> !torch.tuple<int, int, int>\n"
|
||||
" return %3 : !torch.tuple<int, int, int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.detach\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10986,6 +11048,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.linalg_det\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||
" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n"
|
||||
" torch.prim.If %2 -> () {\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" } else {\n"
|
||||
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||
" torch.prim.If.yield\n"
|
||||
" }\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.dropout\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.bool) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -2619,6 +2619,28 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
|
||||
class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(Aten_LinalgDetOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
SmallVector<Value, 3> results = op.getResults();
|
||||
if (!results[1].use_empty() || !results[2].use_empty())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unsupported: _linalg_det results: LU and pivot");
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.getA();
|
||||
Value determinant = rewriter.create<Torch::AtenLinalgDetOp>(
|
||||
loc, results[0].getType(), input);
|
||||
rewriter.replaceAllUsesWith(results[0], determinant);
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
|
||||
// prims.collapse operations.
|
||||
//
|
||||
|
@ -8701,6 +8723,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuIndicesOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAten_LinalgDetOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<
|
||||
DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns);
|
||||
// More specific conv ops
|
||||
|
|
|
@ -404,6 +404,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
|||
target.addIllegalOp<AtenMvOp>();
|
||||
target.addIllegalOp<AtenRenormOp>();
|
||||
target.addIllegalOp<AtenLinalgCrossOp>();
|
||||
target.addIllegalOp<Aten_LinalgDetOp>();
|
||||
target.addIllegalOp<AtenPixelShuffleOp>();
|
||||
target.addIllegalOp<AtenTOp>();
|
||||
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
||||
|
|
|
@ -559,6 +559,9 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ConvolutionBackwardModule2D_basic",
|
||||
"CumsumModule_basic",
|
||||
"DeformConv2D_basic",
|
||||
"DeterminantBatchedModule_F32",
|
||||
"DeterminantDynamicModule_F32",
|
||||
"DeterminantModule_F32",
|
||||
"DiagonalModule_basic",
|
||||
"DiagonalModule_nonsquare",
|
||||
"DiagonalModule_transposed",
|
||||
|
@ -2939,6 +2942,9 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"CumsumStaticModule_basic",
|
||||
"CumsumStaticNegativeDimModule_basic",
|
||||
"DeformConv2D_basic",
|
||||
"DeterminantBatchedModule_F32",
|
||||
"DeterminantDynamicModule_F32",
|
||||
"DeterminantModule_F32",
|
||||
"DiagonalModule_basic",
|
||||
"DiagonalModule_nonsquare",
|
||||
"DiagonalModule_transposed",
|
||||
|
@ -3734,6 +3740,10 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"CumsumStaticModule_basic",
|
||||
"CumsumStaticNegativeDimModule_basic",
|
||||
"DeformConv2D_basic",
|
||||
"DeterminantModule_F32",
|
||||
"DeterminantBatchedModule_F32",
|
||||
"DeterminantDynamicModule_F32",
|
||||
"DeterminantModule_F32",
|
||||
"DiagonalModule_basic",
|
||||
"DiagonalModule_nonsquare",
|
||||
"DiagonalModule_transposed",
|
||||
|
|
|
@ -223,6 +223,19 @@ def aten〇sign〡shape(self: List[int]) -> List[int]:
|
|||
def aten〇sgn〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇linalg_det〡shape(A: List[int]) -> List[int]:
|
||||
assert len(A) == 2 or len(A) == 3
|
||||
assert A[-1] == A[-2]
|
||||
if len(A) == 3:
|
||||
return A[:1]
|
||||
return upstream_shape_functions.zero_dim_tensor(A)
|
||||
|
||||
def aten〇_linalg_det〡shape(A: List[int]) -> Tuple[List[int], List[int], List[int]]:
|
||||
return (aten〇linalg_det〡shape(A), A, A[:-1])
|
||||
|
||||
def aten〇_linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> Tuple[int, int, int]:
|
||||
return (A_rank_dtype[1], A_rank_dtype[1], A_rank_dtype[1])
|
||||
|
||||
def aten〇detach〡shape(self: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -2630,6 +2643,12 @@ def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4,4),], error_types={*all_integer_dtypes()}))
|
||||
def aten〇linalg_det〡dtype(A_rank_dtype: Tuple[int, int]) -> int:
|
||||
self_rank, self_dtype = A_rank_dtype
|
||||
assert not is_integer_dtype(self_dtype)
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, p=0.5, train=False))
|
||||
def aten〇dropout〡dtype(input_rank_dtype: Tuple[int, int], p: float, train: bool) -> int:
|
||||
input_rank, input_dtype = input_rank_dtype
|
||||
|
|
|
@ -699,6 +699,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)")
|
||||
emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)")
|
||||
emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)")
|
||||
emit("aten::linalg_det : (Tensor) -> (Tensor)")
|
||||
emit("aten::_linalg_det : (Tensor) -> (Tensor, Tensor, Tensor)")
|
||||
emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)")
|
||||
emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)")
|
||||
|
|
|
@ -43,6 +43,7 @@ def register_all_tests():
|
|||
from . import slice_like
|
||||
from . import nll_loss
|
||||
from . import index_select
|
||||
from . import linalg_algorithms
|
||||
from . import arange
|
||||
from . import constant_alloc
|
||||
from . import threshold
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
# Also available under a BSD-style license. See LICENSE.
|
||||
|
||||
import torch
|
||||
|
||||
from torch_mlir_e2e_test.framework import TestUtils
|
||||
from torch_mlir_e2e_test.registry import register_test_case
|
||||
from torch_mlir_e2e_test.annotations import annotate_args, export
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class DeterminantModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([None, [(4, 4), torch.float32, True]])
|
||||
def forward(self, A):
|
||||
return torch.linalg.det(A)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DeterminantModule())
|
||||
def DeterminantModule_F32(module, tu: TestUtils):
|
||||
A = tu.rand(4, 4).to(dtype=torch.float32)
|
||||
module.forward(A)
|
||||
|
||||
|
||||
class DeterminantBatchedModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([None, [(3, 4, 4), torch.float32, True]])
|
||||
def forward(self, A):
|
||||
return torch.linalg.det(A)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DeterminantBatchedModule())
|
||||
def DeterminantBatchedModule_F32(module, tu: TestUtils):
|
||||
A = tu.rand(3, 4, 4).to(dtype=torch.float32)
|
||||
module.forward(A)
|
||||
|
||||
|
||||
class DeterminantDynamicModule(torch.nn.Module):
|
||||
@export
|
||||
@annotate_args([None, [(-1, -1, -1), torch.float32, True]])
|
||||
def forward(self, A):
|
||||
return torch.linalg.det(A)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: DeterminantBatchedModule())
|
||||
def DeterminantDynamicModule_F32(module, tu: TestUtils):
|
||||
A = tu.rand(3, 4, 4).to(dtype=torch.float32)
|
||||
module.forward(A)
|
Loading…
Reference in New Issue