mirror of https://github.com/llvm/torch-mlir

reimplement linear lowering torchToMhlo (#1524)

parent 15b249777b
commit 50b524546f
@@ -373,10 +373,10 @@ public:
     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();

-    if (lhsRank != 2 && lhsRank != 3)
-      return op.emitError("aten.Linear called but input rank not 2 or 3");
-    if (rhsRank != 2 && rhsRank != 3)
-      return op.emitError("aten.Linear called but weight rank not 2 or 3");
+    if (lhsRank < 1)
+      return op.emitError("aten.Linear called but input rank 0");
+    if (rhsRank != 2)
+      return op.emitError("aten.Linear called but weight rank not 2");

     return success();
   }

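A sketch, not part of the diff: the relaxed check admits any input rank of at least 1 as long as the weight is 2-D, so a linear over a static rank-4 input such as the one below (function name invented, syntax mirroring the tests further down) now passes the check, whereas the previous rank-2-or-3 restriction rejected it.

// Not from the commit: hypothetical rank-4 input accepted by the new check.
func.func @linear_rank4_sketch(%input: !torch.vtensor<[8,4,6,3],f32>, %weight: !torch.vtensor<[5,3],f32>, %bias: !torch.vtensor<[5],f32>) -> !torch.vtensor<[8,4,6,5],f32> {
  %0 = torch.aten.linear %input, %weight, %bias : !torch.vtensor<[8,4,6,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[8,4,6,5],f32>
  return %0 : !torch.vtensor<[8,4,6,5],f32>
}
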
@@ -406,33 +406,59 @@ public:

     auto lhsTy = lhs.getType().cast<RankedTensorType>();
     auto rhsTy = rhs.getType().cast<RankedTensorType>();
-    auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
-                                rhsTy.getRank() - lhsTy.getRank());
+    auto lhsRank = lhsTy.getRank();
+
+    auto loc = op->getLoc();
+    Value dotLhs;
+    SmallVector<Value> resultDims;
+    // vector * matrix or matrix * matrix can directly use mhlo.dot
+    if (lhsTy.getRank() <= 2) {
+      dotLhs = lhs;
+    } else {
+      // Broadcasting the weight and then using bmm would lead to too much
+      // data copying and extra compute, and hence worse performance.
+      // Instead, reshape the input to a 2-D tensor, use dot to perform the
+      // matrix-matrix multiply, and finally reshape back to the output
+      // shape, which gives better performance.
+
+      // [x_1, x_2, ..., x_n, in_features] * [in_features, out_features]
+      // -> [x_1 * x_2 * ... * x_n , in_features] * [in_features, out_features]
+      auto dotLhsTy = RankedTensorType::get(
+          {ShapedType::kDynamicSize, lhsTy.getShape()[lhsRank - 1]},
+          lhsTy.getElementType());
     const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
-    getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
-                    options.dimSizeIndexBits);
-    auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
-    auto nBatchDims = resultRank - 2;
-    auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
-
-    auto lhsResultDim = nBatchDims;
-    auto rhsResultDim = nBatchDims + 1;
-    auto lhsContractingDim = nBatchDims + 1;
-    auto rhsContractingDim = nBatchDims;
+      Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
+      Value numel = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(intType, 1));
+
+      for (int i = 0; i < lhsRank - 1; ++i) {
+        Value dimValue = rewriter.create<tensor::DimOp>(loc, lhs, i);
+        resultDims.push_back(dimValue);
+        numel = rewriter.create<arith::MulIOp>(
+            loc, numel,
+            rewriter.create<arith::IndexCastOp>(loc, intType, dimValue));
+      }
+      Value lhsLastRankDim = rewriter.create<arith::IndexCastOp>(
+          loc, intType, rewriter.create<tensor::DimOp>(loc, lhs, lhsRank - 1));
+      resultDims.push_back(rewriter.create<tensor::DimOp>(loc, rhs, 1));
+      Value reshapeDim =
+          rewriter
+              .create<mlir::tensor::FromElementsOp>(
+                  op->getLoc(), ValueRange{numel, lhsLastRankDim})
+              .getResult();
+      dotLhs = rewriter.create<mhlo::DynamicReshapeOp>(loc, dotLhsTy, lhs,
+                                                       reshapeDim);
+    }
+    Value matmulOutput =
+        rewriter.create<mhlo::DotOp>(loc, dotLhs, rhs, nullptr);
     auto outTy =
-        castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
-                           lhsContractingDim, rhsContractingDim);
-    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
-        mhlo::DotDimensionNumbersAttr::get(
-            rewriter.getContext(),
-            /*lhsBatchingDimensions=*/batchDims,
-            /*rhsBatchingDimensions=*/batchDims,
-            /*lhsContractingDimensions=*/{lhsContractingDim},
-            /*rhsContractingDimensions=*/{rhsContractingDim});
-    Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
-        op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
+        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
+    // reshape to [x_1, x_2, ..., x_n, out_features]
+    if (dotLhs != lhs) {
+      matmulOutput = rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, outTy, matmulOutput,
+          rewriter.create<mlir::tensor::FromElementsOp>(loc, resultDims));
+    }

     Value matmulPlusBias = matmulOutput;
     if (!biasTy.template isa<Torch::NoneType>()) {
@@ -443,9 +469,7 @@ public:
               .getResult();
     }

-    auto resultTy =
-        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, matmulPlusBias);
+    rewriter.replaceOp(op, matmulPlusBias);
     return success();
   }
 };

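To make the reshape-then-dot path concrete, here is a hand-written sketch, not produced by the pass and not part of this commit, of roughly the IR it is expected to emit for a ?x?x3 input against a 5x3 weight; the function and value names are invented, and exact op and attribute spellings may differ from what the conversion actually generates. A 2x4x3 input, for example, is flattened to (2*4)x3 = 8x3, multiplied by the transposed 3x5 weight, and reshaped back to 2x4x5:

// Hypothetical sketch of the dynamic-shape lowering; names and attribute
// spellings are illustrative, not the literal pass output.
func.func @linear_dynamic_sketch(%input: tensor<?x?x3xf32>, %weight: tensor<5x3xf32>, %bias: tensor<5xf32>) -> tensor<?x?x5xf32> {
  // weight.T : [5, 3] -> [3, 5]
  %wt = "mhlo.transpose"(%weight) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<5x3xf32>) -> tensor<3x5xf32>
  // numel of the leading input dims: d0 * d1
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %input, %c0 : tensor<?x?x3xf32>
  %d1 = tensor.dim %input, %c1 : tensor<?x?x3xf32>
  %d0i = arith.index_cast %d0 : index to i64
  %d1i = arith.index_cast %d1 : index to i64
  %numel = arith.muli %d0i, %d1i : i64
  // flatten the input to 2-D: [d0, d1, 3] -> [numel, 3]
  %c3 = arith.constant 3 : i64
  %flat_shape = tensor.from_elements %numel, %c3 : tensor<2xi64>
  %flat = "mhlo.dynamic_reshape"(%input, %flat_shape) : (tensor<?x?x3xf32>, tensor<2xi64>) -> tensor<?x3xf32>
  // single 2-D matrix-matrix multiply
  %dot = "mhlo.dot"(%flat, %wt) : (tensor<?x3xf32>, tensor<3x5xf32>) -> tensor<?x5xf32>
  // reshape back to [d0, d1, out_features]
  %c5 = arith.constant 5 : i64
  %out_shape = tensor.from_elements %d0i, %d1i, %c5 : tensor<3xi64>
  %matmul = "mhlo.dynamic_reshape"(%dot, %out_shape) : (tensor<?x5xf32>, tensor<3xi64>) -> tensor<?x?x5xf32>
  // bias add, broadcasting [5] along the last result dim
  %out = "chlo.broadcast_add"(%matmul, %bias) {broadcast_dimensions = dense<[2]> : tensor<1xi64>} : (tensor<?x?x5xf32>, tensor<5xf32>) -> tensor<?x?x5xf32>
  return %out : tensor<?x?x5xf32>
}

Compared with broadcasting the weight and running a batched matmul, this keeps a single 2-D dot and only materializes a few shape scalars, which is the performance argument made in the comment above.
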
@@ -499,3 +499,60 @@ func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7
   %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int2 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,15,15],f32>
   return %3 : !torch.vtensor<[1,4,15,15],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.linear(
+// CHECK-NOT: mhlo.dynamic_reshape
+// CHECK: mhlo.transpose
+// CHECK: mhlo.dot
+// CHECK: chlo.broadcast_add
+func.func @torch.aten.linear(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[4,5],f32> {
+  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[4,5],f32>
+  return %1 : !torch.vtensor<[4,5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.linear$nobias(
+// CHECK-NOT: mhlo.dynamic_reshape
+// CHECK: mhlo.transpose
+// CHECK: mhlo.dot
+// CHECK-NOT: chlo.broadcast_add
+func.func @torch.aten.linear$nobias(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[4,5],f32> {
+  %none = torch.constant.none
+  %1 = torch.aten.linear %arg0, %arg1, %none : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.none -> !torch.vtensor<[4,5],f32>
+  return %1 : !torch.vtensor<[4,5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.linear$dynamic(
+// CHECK: mhlo.transpose
+// CHECK: arith.muli
+// CHECK: arith.muli
+// CHECK: tensor.from_elements
+// CHECK: mhlo.dynamic_reshape
+// CHECK: mhlo.dot
+// CHECK: mhlo.dynamic_reshape
+// CHECK: chlo.broadcast_add
+func.func @torch.aten.linear$dynamic(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,5],f32> {
+  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,5],f32>
+  return %1 : !torch.vtensor<[?,?,5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.linear$dynamic4D(
+// CHECK: mhlo.transpose
+// CHECK: arith.muli
+// CHECK: arith.muli
+// CHECK: tensor.from_elements
+// CHECK: mhlo.dynamic_reshape
+// CHECK: mhlo.dot
+// CHECK: mhlo.dynamic_reshape
+// CHECK: chlo.broadcast_add
+func.func @torch.aten.linear$dynamic4D(%arg0: !torch.vtensor<[?,?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,?,5],f32> {
+  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,?,5],f32>
+  return %1 : !torch.vtensor<[?,?,?,5],f32>
+}

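For the static 2-D cases, a rough sketch of the IR the CHECK lines anticipate (again hand-written with invented names, not the literal pass output): the weight is transposed, a plain mhlo.dot performs the multiply, and chlo.broadcast_add adds the bias, with no dynamic_reshape involved, which is exactly what the CHECK-NOT lines assert.

// Hypothetical sketch of the static-shape lowering; names are illustrative.
func.func @linear_static_sketch(%input: tensor<4x3xf32>, %weight: tensor<5x3xf32>, %bias: tensor<5xf32>) -> tensor<4x5xf32> {
  %wt = "mhlo.transpose"(%weight) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<5x3xf32>) -> tensor<3x5xf32>
  %dot = "mhlo.dot"(%input, %wt) : (tensor<4x3xf32>, tensor<3x5xf32>) -> tensor<4x5xf32>
  %out = "chlo.broadcast_add"(%dot, %bias) {broadcast_dimensions = dense<[1]> : tensor<1xi64>} : (tensor<4x5xf32>, tensor<5xf32>) -> tensor<4x5xf32>
  return %out : tensor<4x5xf32>
}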