mirror of https://github.com/llvm/torch-mlir
[LINALG] Lower `aten.Matmul` to `linalg.BatchMatmul`
This commit lowers `aten.matmul` to `linalg.BatchMatmul` under the following conditions: 1. The result of matrix multiplication must have batch dimensions, i.e., rank greater than 2. 2. The resultant matrix must have at most 1 dynamic batch dimension. It also handles broadcasting of batch dimensions when batch dimensions of the matrices are broadcastable. Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>pull/977/merge snapshot-20220625.514
parent
f774e63abd
commit
1be604bfd3
|
@ -61,6 +61,7 @@ TOSA_PASS_SET = {
|
|||
"ElementwisePowModule_basic",
|
||||
"BmmModule_basic",
|
||||
"MmDagModule_basic",
|
||||
"Matmul4dStatic_basic",
|
||||
"Matmul_dot",
|
||||
"Matmul_3d",
|
||||
"RsubFloatModule_basic",
|
||||
|
|
|
@ -890,90 +890,6 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Broadcasts input tensor based on the broadcastToShape.
|
||||
static LogicalResult broadcastToGivenShape(Operation *op,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
Value input,
|
||||
SmallVector<Value> broadcastToShape,
|
||||
Value &result) {
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
if (broadcastToShape.size() < inputShape.size()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid shape: broadcastToShape size must not be smaller than the "
|
||||
"size of the input shape");
|
||||
}
|
||||
|
||||
Type elementType = inputType.getElementType();
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
SmallVector<Value> outShape;
|
||||
|
||||
// Create affine map and shapes for tensor initialization.
|
||||
SmallVector<AffineExpr> outExpr;
|
||||
Value zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
size_t diff = broadcastToShape.size() - inputShape.size();
|
||||
for (size_t i = 0; i < broadcastToShape.size(); i++) {
|
||||
Value shapeValue = broadcastToShape[i];
|
||||
size_t j = i - diff;
|
||||
if (i < diff) {
|
||||
Value isValid = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, shapeValue, zero);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"negative values not allowed in new dimensions"));
|
||||
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
|
||||
continue;
|
||||
}
|
||||
if (inputShape[j] == 1) {
|
||||
// Broadcast singleton dimension
|
||||
Value one =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value select = rewriter.create<arith::SelectOp>(
|
||||
loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
|
||||
outShape.push_back(select);
|
||||
outExpr.push_back(mlir::getAffineConstantExpr(0, context));
|
||||
continue;
|
||||
}
|
||||
// Non-broadcast case
|
||||
Value dim = getDimOp(rewriter, loc, input, j);
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value isEqual = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, castIndexToInt64(rewriter, loc, dim),
|
||||
shapeValue);
|
||||
Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"only broadcasting singleton dimensions supported"));
|
||||
outShape.push_back(dim);
|
||||
outExpr.push_back(mlir::getAffineDimExpr(i, context));
|
||||
}
|
||||
|
||||
Value outTensor =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
AffineMap::get(broadcastToShape.size(), 0, outExpr, context),
|
||||
rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
|
||||
SmallVector<StringRef> iteratorTypes(broadcastToShape.size(), "parallel");
|
||||
result = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outTensor.getType(), input, outTensor, indexingMaps,
|
||||
iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
|
||||
public:
|
||||
|
@ -995,8 +911,8 @@ public:
|
|||
rewriter, op.getLoc(), getTypeConverter(), inShape);
|
||||
|
||||
Value result;
|
||||
if (failed(broadcastToGivenShape(op, rewriter, self, inShapeConverted,
|
||||
result))) {
|
||||
if (failed(torch_to_linalg::broadcastToGivenShape(
|
||||
op, rewriter, self, inShapeConverted, result))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
|
@ -1060,8 +976,8 @@ public:
|
|||
for (unsigned i = 0; i < selfSizes.size(); i++)
|
||||
selfSizes[i] = castIndexToInt64(rewriter, loc, selfSizes[i]);
|
||||
Value broadcastedSrc;
|
||||
if (failed(broadcastToGivenShape(op, rewriter, src, selfSizes,
|
||||
broadcastedSrc))) {
|
||||
if (failed(torch_to_linalg::broadcastToGivenShape(
|
||||
op, rewriter, src, selfSizes, broadcastedSrc))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
|
|
|
@ -164,12 +164,16 @@ public:
|
|||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
auto lhsType = lhs.getType().cast<RankedTensorType>();
|
||||
auto rhsType = rhs.getType().cast<RankedTensorType>();
|
||||
|
||||
unsigned lhsRank = lhs.getType().cast<RankedTensorType>().getRank();
|
||||
unsigned rhsRank = rhs.getType().cast<RankedTensorType>().getRank();
|
||||
// Get the rank of both matrix.
|
||||
unsigned lhsRank = lhsType.getRank();
|
||||
unsigned rhsRank = rhsType.getRank();
|
||||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
Type elementType = newResultType.cast<TensorType>().getElementType();
|
||||
auto resultType = newResultType.cast<RankedTensorType>();
|
||||
Type elementType = resultType.getElementType();
|
||||
|
||||
// The different cases of torch_matmul op is mentioned here:
|
||||
// https://pytorch.org/docs/stable/generated/torch.matmul.html
|
||||
|
@ -228,39 +232,134 @@ public:
|
|||
}
|
||||
|
||||
// Fourth Case: Batch-Matrix Multiplication.
|
||||
// TODO: Broadcasting of batch dimension is remaining.
|
||||
if (lhsRank >= 3 && rhsRank >= 3 && lhsRank == rhsRank) {
|
||||
// TODO: Handle batch matrix multiplication when one of the matrix is unity
|
||||
// rank and the other has batch dimension.
|
||||
if (lhsRank > 1 && rhsRank > 1) {
|
||||
unsigned maxRank = std::max(lhsRank, rhsRank);
|
||||
unsigned minRank = std::min(lhsRank, rhsRank);
|
||||
unsigned batchRank = maxRank - 2;
|
||||
|
||||
unsigned batchRank = lhsRank - 2;
|
||||
SmallVector<Value, 4> resultShape;
|
||||
// At least one of the matrix must have rank greater than 2.
|
||||
if (batchRank <= 0) {
|
||||
return rewriter.notifyMatchFailure(op, "expected batch dimensions");
|
||||
}
|
||||
|
||||
// The `broadcastedBatchShape` contains batch dimensions of the resultant
|
||||
// matrix.
|
||||
SmallVector<Value> broadcastedBatchShape(batchRank);
|
||||
Value maxRankMatrix = (lhsRank > rhsRank) ? lhs : rhs;
|
||||
Value maxDim;
|
||||
// Compute broadcasted batch dimensions if the batch dimensions of
|
||||
// the matrices are broadcastable.
|
||||
for (unsigned i = 1; i <= batchRank; i++) {
|
||||
if (i <= minRank - 2) {
|
||||
Value lhsDim = getDimOp(rewriter, loc, lhs, lhsRank - 2 - i);
|
||||
Value rhsDim = getDimOp(rewriter, loc, rhs, rhsRank - 2 - i);
|
||||
maxDim = rewriter.createOrFold<arith::MaxUIOp>(loc, lhsDim, rhsDim);
|
||||
} else {
|
||||
maxDim = getDimOp(rewriter, loc, maxRankMatrix, maxRank - 2 - i);
|
||||
}
|
||||
broadcastedBatchShape[batchRank - i] = maxDim;
|
||||
}
|
||||
|
||||
Value lhsDim0 = getDimOp(rewriter, loc, lhs, lhsRank - 2);
|
||||
Value lhsDim1 = getDimOp(rewriter, loc, lhs, lhsRank - 1);
|
||||
Value rhsDim0 = getDimOp(rewriter, loc, rhs, rhsRank - 2);
|
||||
Value rhsDim1 = getDimOp(rewriter, loc, rhs, rhsRank - 1);
|
||||
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
|
||||
|
||||
// Compute broadcasted shape of both the matrices in integer format.
|
||||
SmallVector<Value> lhsBroadcastToShape(broadcastedBatchShape);
|
||||
lhsBroadcastToShape.push_back(lhsDim0);
|
||||
lhsBroadcastToShape.push_back(lhsDim1);
|
||||
SmallVector<Value> rhsBroadcastToShape(broadcastedBatchShape);
|
||||
rhsBroadcastToShape.push_back(rhsDim0);
|
||||
rhsBroadcastToShape.push_back(rhsDim1);
|
||||
for (unsigned i = 0; i < maxRank; i++) {
|
||||
lhsBroadcastToShape[i] =
|
||||
castIndexToInt64(rewriter, loc, lhsBroadcastToShape[i]);
|
||||
rhsBroadcastToShape[i] =
|
||||
castIndexToInt64(rewriter, loc, rhsBroadcastToShape[i]);
|
||||
}
|
||||
|
||||
// Broadcast the batch dimensions of both the matrices.
|
||||
Value broadcastedLhs, broadcastedRhs;
|
||||
if (failed(torch_to_linalg::broadcastToGivenShape(
|
||||
op, rewriter, lhs, lhsBroadcastToShape, broadcastedLhs))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
if (failed(torch_to_linalg::broadcastToGivenShape(
|
||||
op, rewriter, rhs, rhsBroadcastToShape, broadcastedRhs))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unable to perform broadcast operation");
|
||||
}
|
||||
|
||||
// Check if the result of the matrix multiplication has more than one
|
||||
// dynamic batch dimensions.
|
||||
ArrayRef<int64_t> batchDimsInt = resultType.getShape().drop_back(2);
|
||||
bool multipleDynamicBatchDims =
|
||||
llvm::count(batchDimsInt, kUnknownSize) > 1;
|
||||
|
||||
// TODO: Lowering to `linalg.BatchMatmul` is only possible when there is
|
||||
// at most one dynamic batch dimension due to limited support of the
|
||||
// `tensor.ExpandShape` op.
|
||||
if (!multipleDynamicBatchDims) {
|
||||
// Collapse the batch dimensions into one dimension. The resultant rank
|
||||
// will always be 3.
|
||||
SmallVector<ReassociationIndices> reassociation(3);
|
||||
for (unsigned i = 0, j = 0; i < maxRank; i++) {
|
||||
if (i >= batchRank)
|
||||
j++;
|
||||
reassociation[j].push_back(i);
|
||||
}
|
||||
Value collapsedLhs = rewriter.create<tensor::CollapseShapeOp>(
|
||||
op->getLoc(), broadcastedLhs, reassociation);
|
||||
Value collapsedRhs = rewriter.create<tensor::CollapseShapeOp>(
|
||||
op->getLoc(), broadcastedRhs, reassociation);
|
||||
|
||||
// Compute the result shape after collapsing the batch dimensions.
|
||||
SmallVector<Value> collapsedResultShape;
|
||||
collapsedResultShape.push_back(broadcastedBatchShape[0]);
|
||||
for (unsigned i = 1; i < batchRank; i++) {
|
||||
collapsedResultShape[0] = rewriter.createOrFold<arith::MulIOp>(
|
||||
loc, collapsedResultShape[0], broadcastedBatchShape[i]);
|
||||
}
|
||||
collapsedResultShape.push_back(lhsDim0);
|
||||
collapsedResultShape.push_back(rhsDim1);
|
||||
SmallVector<OpFoldResult> updatedCollapseResultShape =
|
||||
getAsOpFoldResult(collapsedResultShape);
|
||||
|
||||
Value initTensor = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, updatedCollapseResultShape, elementType);
|
||||
Value c0 = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getZeroAttr(elementType));
|
||||
Value zeroTensor =
|
||||
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||
|
||||
Value batchMatMul =
|
||||
rewriter
|
||||
.create<linalg::BatchMatmulOp>(
|
||||
loc, zeroTensor.getType(),
|
||||
ValueRange{collapsedLhs, collapsedRhs}, zeroTensor)
|
||||
.getResult(0);
|
||||
Value expandResult = rewriter.create<tensor::ExpandShapeOp>(
|
||||
loc, resultType, batchMatMul, reassociation);
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType,
|
||||
expandResult);
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<AffineExpr> lhsExpr;
|
||||
SmallVector<AffineExpr> rhsExpr;
|
||||
SmallVector<AffineExpr> outExpr;
|
||||
SmallVector<StringRef> iteratorTypes;
|
||||
|
||||
// Since broadcasting is a TODO, check whether the lhs and rhs batch
|
||||
// dimension match.
|
||||
for (unsigned i = 0; i < batchRank; i++) {
|
||||
Value lhsBatch = getDimOp(rewriter, loc, lhs, i);
|
||||
Value rhsBatch = getDimOp(rewriter, loc, rhs, i);
|
||||
resultShape.push_back(lhsBatch);
|
||||
lhsExpr.push_back(rewriter.getAffineDimExpr(i));
|
||||
rhsExpr.push_back(rewriter.getAffineDimExpr(i));
|
||||
outExpr.push_back(rewriter.getAffineDimExpr(i));
|
||||
iteratorTypes.push_back(getParallelIteratorTypeName());
|
||||
checkDimEqualHelper(rewriter, loc, lhsBatch, rhsBatch);
|
||||
}
|
||||
|
||||
Value lhsDim0 = getDimOp(rewriter, loc, lhs, batchRank);
|
||||
Value lhsDim1 = getDimOp(rewriter, loc, lhs, batchRank + 1);
|
||||
Value rhsDim0 = getDimOp(rewriter, loc, rhs, batchRank);
|
||||
Value rhsDim1 = getDimOp(rewriter, loc, rhs, batchRank + 1);
|
||||
checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0);
|
||||
|
||||
// Push the final matrix dimension.
|
||||
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
|
||||
|
||||
lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(batchRank),
|
||||
rewriter.getAffineDimExpr(batchRank + 1)});
|
||||
rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(batchRank + 1),
|
||||
|
@ -268,9 +367,10 @@ public:
|
|||
outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(batchRank),
|
||||
rewriter.getAffineDimExpr(batchRank + 2)});
|
||||
|
||||
Value initTensor0 =
|
||||
SmallVector<Value> resultShape(broadcastedBatchShape);
|
||||
resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1});
|
||||
Value zeroTensor =
|
||||
createZeroInitTensor(rewriter, loc, resultShape, elementType);
|
||||
|
||||
auto indexingMaps =
|
||||
AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr});
|
||||
iteratorTypes.insert(iteratorTypes.end(),
|
||||
|
@ -279,7 +379,8 @@ public:
|
|||
Value finalRes =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0,
|
||||
loc, zeroTensor.getType(),
|
||||
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor,
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
|
|
|
@ -256,3 +256,85 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
|
|||
iteratorTypes, bodyBuild)
|
||||
.getResult(0);
|
||||
}
|
||||
|
||||
// Broadcasts input tensor based on the broadcastToShape.
|
||||
LogicalResult torch_to_linalg::broadcastToGivenShape(
|
||||
Operation *op, PatternRewriter &rewriter, Value input,
|
||||
SmallVector<Value> broadcastToShape, Value &result) {
|
||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
if (broadcastToShape.size() < inputShape.size()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid shape: broadcastToShape size must not be smaller than the "
|
||||
"size of the input shape");
|
||||
}
|
||||
|
||||
Type elementType = inputType.getElementType();
|
||||
Location loc = op->getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
SmallVector<Value> outShape;
|
||||
|
||||
// Create affine map and shapes for tensor initialization.
|
||||
SmallVector<AffineExpr> outExpr;
|
||||
Value zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
size_t diff = broadcastToShape.size() - inputShape.size();
|
||||
for (size_t i = 0; i < broadcastToShape.size(); i++) {
|
||||
Value shapeValue = broadcastToShape[i];
|
||||
size_t j = i - diff;
|
||||
if (i < diff) {
|
||||
Value isValid = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, shapeValue, zero);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"negative values not allowed in new dimensions"));
|
||||
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
|
||||
continue;
|
||||
}
|
||||
if (inputShape[j] == 1) {
|
||||
// Broadcast singleton dimension
|
||||
Value one =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value select = rewriter.create<arith::SelectOp>(
|
||||
loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
|
||||
outShape.push_back(select);
|
||||
outExpr.push_back(mlir::getAffineConstantExpr(0, context));
|
||||
continue;
|
||||
}
|
||||
// Non-broadcast case
|
||||
Value dim = getDimOp(rewriter, loc, input, j);
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value isEqual = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, castIndexToInt64(rewriter, loc, dim),
|
||||
shapeValue);
|
||||
Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
|
||||
rewriter.create<cf::AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"only broadcasting singleton dimensions supported"));
|
||||
outShape.push_back(dim);
|
||||
outExpr.push_back(mlir::getAffineDimExpr(i, context));
|
||||
}
|
||||
|
||||
Value outTensor =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
AffineMap::get(broadcastToShape.size(), 0, outExpr, context),
|
||||
rewriter.getMultiDimIdentityMap(broadcastToShape.size())};
|
||||
SmallVector<StringRef> iteratorTypes(broadcastToShape.size(), "parallel");
|
||||
result = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outTensor.getType(), input, outTensor, indexingMaps,
|
||||
iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -54,6 +54,13 @@ Value createElementwiseLinalgGeneric(
|
|||
OpBuilder &b, Location loc, ValueRange tensorOperands,
|
||||
Type resultElementType,
|
||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild);
|
||||
|
||||
// Broadcasts input tensor based on the broadcastToShape.
|
||||
LogicalResult broadcastToGivenShape(Operation *op, PatternRewriter &rewriter,
|
||||
Value input,
|
||||
SmallVector<Value> broadcastToShape,
|
||||
Value &result);
|
||||
|
||||
} // namespace torch_to_linalg
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -1129,8 +1129,10 @@ public:
|
|||
SmallVector<int32_t> transposedLhsDims;
|
||||
|
||||
// Step: generate the common dim/shape information
|
||||
bool hasDynamicDims = false;
|
||||
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
|
||||
bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
|
||||
hasDynamicDims |= isDynamicDim;
|
||||
if (isDynamicDim ||
|
||||
lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
|
||||
commonValue *= lhsBroadcastedShape[dim];
|
||||
|
@ -1138,11 +1140,13 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: Handle the case when there are dynamic batch dimensions.
|
||||
if (hasDynamicDims)
|
||||
commonValue = ShapedType::kDynamicSize;
|
||||
|
||||
// Step: generate the LHS squeezed dim/shape information.
|
||||
bool hasDynamicDims = false;
|
||||
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
|
||||
bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
|
||||
hasDynamicDims |= isDynamicDim;
|
||||
if (!isDynamicDim &&
|
||||
lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) {
|
||||
lhsSqueezedValue *= lhsBroadcastedShape[dim];
|
||||
|
|
|
@ -130,3 +130,82 @@ def Matmul_4d(module, tu: TestUtils):
|
|||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class Matmul4dStatic(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, 5, 6, 7], torch.float32, True),
|
||||
([4, 5, 7, 6], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: Matmul4dStatic())
|
||||
def Matmul4dStatic_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MatmulStaticBroadcast(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, 1, 6, 7], torch.float32, True),
|
||||
([8, 1, 5, 7, 6], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MatmulStaticBroadcast())
|
||||
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MatmulSingleDynamicBatchDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, -1, -1, -1], torch.float32, True),
|
||||
([4, -1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MatmulSingleDynamicBatchDim())
|
||||
def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class MatmulBroadcastBatchDim(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([4, -1, -1, -1], torch.float32, True),
|
||||
([-1, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, lhs, rhs):
|
||||
return torch.matmul(lhs, rhs)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MatmulBroadcastBatchDim())
|
||||
def MatmulBroadcastBatchDim_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6))
|
||||
|
||||
|
|
Loading…
Reference in New Issue