mirror of https://github.com/llvm/torch-mlir
Add broadcast
parent
a459e09ab7
commit
8853dfbc74
|
@ -377,7 +377,23 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
|
|||
def forward(self, tensor):
|
||||
return self.softmax.forward(tensor)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: SoftmaxIntArgTypeF64Module())
|
||||
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randn(3, 2, 4).double())
|
||||
|
||||
class BroadcastToModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, 1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.broadcast_to(x, [1, -1, -1, 4])
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: BroadcastToModule())
|
||||
def BroadcastToModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1, 1))
|
|
@ -1475,6 +1475,20 @@ def Torch_AtenExpandOp : Torch_Op<"aten.expand", [
|
|||
let assemblyFormat = "$self `,` $size `,` $implicit attr-dict `:` type($self) `,` type($size) `,` type($implicit) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::broadcast_to : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -2134,6 +2134,105 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenBroadcastToOp op, llvm::ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
AtenBroadcastToOp::Adaptor adaptor(operands);
|
||||
Value self = adaptor.self();
|
||||
auto selfType = self.getType().cast<RankedTensorType>();
|
||||
ArrayRef<int64_t> selfShape = selfType.getShape();
|
||||
Type elementType = selfType.getElementType();
|
||||
Location loc = op.getLoc();
|
||||
MLIRContext *context = op->getContext();
|
||||
|
||||
SmallVector<Value> inShape, outShape;
|
||||
if (!getListConstructElements(adaptor.size(), inShape)) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: the size list is not from list construct");
|
||||
}
|
||||
SmallVector<Value> inShapeConverted =
|
||||
getTypeConvertedValues(rewriter, loc, getTypeConverter(), inShape);
|
||||
if (inShape.size() < selfShape.size())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid shape: must not be smaller than rank of tensor");
|
||||
size_t diff = inShape.size() - selfShape.size();
|
||||
|
||||
// Create affine map and shapes for tensor initialization.
|
||||
SmallVector<AffineExpr> outExpr;
|
||||
Value zero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
for (size_t i = 0; i < inShape.size(); i++) {
|
||||
Value shapeValue = inShapeConverted[i];
|
||||
size_t j = i - diff;
|
||||
if (i < diff) {
|
||||
Value isValid = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::sge, shapeValue, zero);
|
||||
rewriter.create<AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"negative values not allowed in new dimensions"));
|
||||
outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
|
||||
continue;
|
||||
}
|
||||
if (selfShape[j] == 1) {
|
||||
// Broadcast singleton dimension
|
||||
Value one =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value select = rewriter.create<SelectOp>(
|
||||
loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
|
||||
outShape.push_back(select);
|
||||
outExpr.push_back(mlir::getAffineConstantExpr(0, context));
|
||||
continue;
|
||||
}
|
||||
// Non-broadcast case
|
||||
Value dim = getDimOp(rewriter, loc, self, j);
|
||||
Value isNegative = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, shapeValue, zero);
|
||||
Value isEqual = rewriter.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim),
|
||||
shapeValue);
|
||||
Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
|
||||
rewriter.create<AssertOp>(
|
||||
loc, isValid,
|
||||
rewriter.getStringAttr(
|
||||
"only broadcasting singleton dimensions supported"));
|
||||
outShape.push_back(dim);
|
||||
outExpr.push_back(mlir::getAffineDimExpr(i, context));
|
||||
}
|
||||
|
||||
Value outTensor =
|
||||
rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
|
||||
|
||||
SmallVector<AffineMap> indexingMaps = {
|
||||
AffineMap::get(inShape.size(), 0, outExpr, context),
|
||||
rewriter.getMultiDimIdentityMap(inShape.size())};
|
||||
SmallVector<StringRef> iteratorTypes(inShape.size(), "parallel");
|
||||
Value result = rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, outTensor.getType(), self, outTensor,
|
||||
indexingMaps, iteratorTypes,
|
||||
[](OpBuilder &b, Location loc, ValueRange args) {
|
||||
b.create<linalg::YieldOp>(loc, args[0]);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -2195,6 +2294,8 @@ public:
|
|||
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenLayerNormOp>();
|
||||
patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenBroadcastToOp>();
|
||||
patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenArgmaxOp>();
|
||||
patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenSizeIntOp>();
|
||||
|
|
|
@ -90,7 +90,8 @@ public:
|
|||
if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
|
||||
copyToValueTensorOps.push_back(copyToValueTensor);
|
||||
} else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
|
||||
AtenTransposeIntOp, TensorStaticInfoCastOp>(op)) {
|
||||
AtenTransposeIntOp, TensorStaticInfoCastOp>(op),
|
||||
AtenBroadcastToOp > (op)) {
|
||||
viewLikeOps.push_back(op);
|
||||
llvm::append_range(workList, op->getResult(0).getUsers());
|
||||
} else {
|
||||
|
|
|
@ -347,6 +347,8 @@ public:
|
|||
targetDim = size == -1 ? inputDim : size;
|
||||
};
|
||||
return visitExpandLikeOp(expand, expand.size(), operands, setDim);
|
||||
} else if (auto broadcast = dyn_cast<AtenBroadcastToOp>(op)) {
|
||||
return visitBroadcastToOp(broadcast, broadcast.size(), operands);
|
||||
} else if (auto repeat = dyn_cast<AtenRepeatOp>(op)) {
|
||||
// The repeats list specify the number of times to repeat along each dim
|
||||
// of the original tensor.
|
||||
|
@ -447,6 +449,9 @@ private:
|
|||
ArrayRef<LatticeElement<ValueKnowledge> *> operands,
|
||||
SetDimSizePerListItemFn setDim);
|
||||
ChangeResult
|
||||
visitBroadcastToOp(Operation *op, Value list,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult
|
||||
visitAtenCatOp(AtenCatOp op,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult
|
||||
|
@ -997,6 +1002,20 @@ ChangeResult TypeAnalyzer::visitExpandLikeOp(
|
|||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
ChangeResult TypeAnalyzer::visitBroadcastToOp(
|
||||
Operation *op, Value list,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = input.dtype;
|
||||
if (!input.hasSizes)
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
|
||||
fillInSizesGivenSizesList(knowledge, list);
|
||||
return getLatticeElement(op->getResult(0)).join(knowledge);
|
||||
}
|
||||
|
||||
// `torch.aten.cat` concatenates the given sequence of seq tensors in the given
|
||||
// dimension. The output has the same sizes as the input for all dimensions
|
||||
// except the given dimension.
|
||||
|
|
|
@ -520,6 +520,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
||||
emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
|
||||
emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
|
||||
emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
|
||||
emit("aten::item : (Tensor) -> (Scalar)")
|
||||
|
|
Loading…
Reference in New Issue