mirror of https://github.com/llvm/torch-mlir
[TOSA] Fix broadcast_to input and output different shape support (#1855)
parent 83534370c3
commit cc819e73dd
@@ -679,6 +679,10 @@ TOSA_PASS_SET = {
     "HardsigmoidRandomModule_basic",
     "HardswishModule_basic",
     "HardswishRandomModule_basic",
+    "FullLikeModuleInt2DStatic_basic",
+    "FullModuleInt3D_basic",
+    "FullModuleFloat2D_basic",
+    "RepeatModule_basic"
 }
 
 LTC_XFAIL_SET = {
@@ -3144,31 +3144,70 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
         op, "Only floating-point or integer datatype legalization supported");
   }
 
-  SmallVector<int64_t> outShape;
-  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outShape)))
+  SmallVector<int64_t> resultShape;
+  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape)))
     return rewriter.notifyMatchFailure(op,
                                        "size must consist of Scalar constants");
+  // Get the result type
+  auto resultType = getTypeConverter()->convertType(op.getType());
 
   SmallVector<int64_t> inputShape(
       makeShapeTorchCompatible(selfType.getShape()));
-  if (inputShape.size() == outShape.size() || inputShape.size() == 0) {
-    // Check for identity case i.e, for ex: [a, b, c] -> [a, b, c]. If this is
-    // true then we can replace the op result with the input operand
-    // irrespective of the users of the op result.
-    if (!llvm::equal(inputShape, outShape)) {
-      for (auto user : op->getResult(0).getUsers()) {
-        // This case is only supported if the result of the `broadcast_to` op is
-        // not used by an op which is a view like.
-        if (isViewLikeOp(user)) {
+  // Check for the identity case, e.g. [a, b, c] -> [a, b, c]. If this holds,
+  // we can replace the op result with the input operand directly.
+  if (llvm::equal(inputShape, resultShape)) {
+    // If we reach here, then broadcasting is not required since the input
+    // and result have the same shape.
+    op.replaceAllUsesWith(op.getSelf());
+    rewriter.eraseOp(op);
+    return success();
+  } else if (selfType.hasRank() &&
+             (selfType.getRank() == (int64_t)resultShape.size() ||
+              selfType.getRank() == 0)) {
+    // Right now, to support the limited cases where the input and result
+    // shapes are not equal, we constrain the input to be either of rank 0 or
+    // of the same rank as the result, and then check broadcasting
+    // compatibility for the latter case. For broadcasting compatibility, the
+    // shapes of input and result must either be equal at each dimension or
+    // one of them must be 1.
+    if (selfType.getRank() != 0) {
+      for (unsigned i = 0; i < inputShape.size(); i++) {
+        if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
+            resultShape[i] != 1) {
           return rewriter.notifyMatchFailure(
-              op, "unimplemented: broadcast not supported for this case");
+              op, "unimplemented: either the shape of input and result should "
+                  "be equal at each dimension or one of them should be 1.");
         }
       }
     }
-    // If we reach here, then it means the given case is handled by implicit
-    // broadcasting done by tosa.
-    op.replaceAllUsesWith(op.getSelf());
-    rewriter.eraseOp(op);
+
+    // If the above conditions hold true, we can directly create a constant
+    // zero tensor with the same shape as the result.
+    SmallVector<int64_t> zeroTensorShape{resultShape};
+
+    // Create the 0 constant tensor.
+    int64_t totalNumElements = 1;
+    for (auto dimSize : zeroTensorShape) {
+      totalNumElements = dimSize * totalNumElements;
+    }
+    // There is some danger here. For edge cases in floating point, x + 0 != x.
+    // The cases are denormalized values, which may get flushed, and
+    // -0 + 0 = +0 (the sign bit flips). These are probably acceptable in the
+    // short term, but we should put a comment acknowledging the danger, as
+    // there isn't an op that avoids the denorm flushing.
+    SmallVector<int64_t> intValues(totalNumElements, 0);
+    SmallVector<float> floatValues(totalNumElements, 0.0);
+    Value zeroTensor = selfType.getElementType().isa<mlir::FloatType>()
+                           ? tosa::getConstTensor<float>(
+                                 rewriter, op, floatValues, zeroTensorShape)
+                                 .value()
+                           : tosa::getConstTensor<int64_t>(
+                                 rewriter, op, intValues, zeroTensorShape)
+                                 .value();
+
+    // Use add broadcast
+    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType, adaptor.getSelf(),
+                                             zeroTensor);
     return success();
   }
   return rewriter.notifyMatchFailure(
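For reference, the compatibility rule the new code enforces is the usual one-or-equal convention: a rank-0 input broadcasts to any result shape, and an equal-rank input must match the result at every dimension except where one of the two extents is 1. Below is a minimal standalone sketch of that predicate; the helper name isBroadcastCompatible is illustrative, not part of the patch.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical helper mirroring the patch's check: shapes are compatible
    // when the input is rank 0, or when both shapes have the same rank and
    // every dimension pair is equal or has an extent of 1 on either side.
    static bool isBroadcastCompatible(const std::vector<int64_t> &inputShape,
                                      const std::vector<int64_t> &resultShape) {
      if (inputShape.empty()) // rank-0 input broadcasts unconditionally
        return true;
      if (inputShape.size() != resultShape.size())
        return false; // the patch only handles rank-0 or equal-rank inputs
      for (std::size_t i = 0; i < inputShape.size(); i++)
        if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
            resultShape[i] != 1)
          return false;
      return true;
    }

Under this rule, [3, 1, 5] -> [3, 4, 5] is accepted (the middle extent is 1), while [3, 2, 5] -> [3, 4, 5] is rejected with the notifyMatchFailure message above.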
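The floating-point caveat in the patch's comment is easy to confirm on the host: under the default round-to-nearest mode, -0.0 + (+0.0) evaluates to +0.0, so broadcasting by adding zero does not preserve the sign of a negative zero; denormal inputs raise the analogous concern on hardware that flushes them to zero. A self-contained check:

    #include <cmath>
    #include <cstdio>

    int main() {
      float negZero = -0.0f;
      // Round-to-nearest: (-0.0f) + (+0.0f) == +0.0f, so the sign bit flips.
      float sum = negZero + 0.0f;
      std::printf("signbit(-0) = %d, signbit(-0 + 0) = %d\n",
                  std::signbit(negZero) ? 1 : 0, std::signbit(sum) ? 1 : 0);
      return 0;
    }

This prints signbit(-0) = 1, signbit(-0 + 0) = 0, which is exactly the sign-bit flip the comment warns about.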
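The zero constant itself is a splat: its element count is simply the product of the result extents, and the patch materializes that many zeros for tosa::getConstTensor. A one-function sketch of the same size computation follows (numElements is an illustrative name); a splat DenseElementsAttr would avoid the host-side buffer, but the patch keeps the explicit form.

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Product of all extents; an empty shape (rank 0) folds to 1, matching
    // the single element of a scalar tensor.
    static int64_t numElements(const std::vector<int64_t> &shape) {
      return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                             std::multiplies<int64_t>());
    }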