E2E support for AtenEmbeddingBagPaddingIdxOp SUM Mode (#1066)

pull/1130/head
Vidush Singhal 2022-08-01 16:44:11 -04:00 committed by GitHub
parent 554570f3ab
commit ed13ebfd8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 436 additions and 0 deletions

View File

@ -5036,6 +5036,40 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
}];
}
def Torch_AtenEmbeddingBagPaddingIdxOp : Torch_Op<"aten.embedding_bag.padding_idx", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$weight,
AnyTorchTensorType:$indices,
AnyTorchTensorType:$offsets,
Torch_BoolType:$scale_grad_by_freq,
Torch_IntType:$mode,
Torch_BoolType:$sparse,
AnyTorchOptionalTensorType:$per_sample_weights,
Torch_BoolType:$include_last_offset,
AnyTorchOptionalIntType:$padding_idx
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1,
AnyTorchTensorType:$result2,
AnyTorchTensorType:$result3
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenEmbeddingBagPaddingIdxOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 9, 4);
}
void AtenEmbeddingBagPaddingIdxOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 9, 4);
}
}];
}
def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -153,6 +153,13 @@ enum MemoryFormat {
//===----------------------------------------------------------------------===//
enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
//===----------------------------------------------------------------------===//
// Possible value for `EmbeddingBag Mode` argument for Embedding bag ops.
// Source:
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
//===-----------------------------------------------------------------------===//
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
} // namespace torch_upstream
} // namespace torch
} // namespace mlir

View File

@ -168,6 +168,264 @@ public:
};
} // namespace
namespace {
// AtenEmbeddingPaddingIdxOp
// SUM mode == integer 0
// Sums bags of embeddings together from a weight tensor based on an index and
// offset Vector. Example arguments weight = [[1, 3, 5, 3],
// [3, 4, 2, 1],
// [2, 2, 3, 2],
// [0, 4, 2, 1]]
//
// indices = [0, 2, 3, 1, 2, 3, 2, 1, 0, 1]
// offsets = [0, 3, 5]
//
// output_tensor = initZeroTensor(offsets_length, embedding_size)
//
// for i in range(offsets_length): <- dim0
// for j in range(indices_length): <- dim1
// for k in range(embedding_size): <- dim2
// if(offsets[i] <= j and j < offsets[i+1]):
// output_tensor[i][k] = output_tensor[i][k] +
// weight[indices[j]][k]
// else:
// break
//
// Indexing maps for linalg::Generic ops
//
//
// indices_indexing_map = (d0, d1, d2) -> (d1)
// offset_indexing_map = (d0, d1, d2) -> (d0)
// output_indexing_map = (d0, d1, d2) -> (d0, d2)
//
// TODO: Find an optimal lowering.
// current lowering is not optimal for bags of large embeddings.
// Since it traverses the output tensor multiple times.
//
//
class ConvertAtenEmbeddingBagPaddingIdxOp
: public OpConversionPattern<AtenEmbeddingBagPaddingIdxOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenEmbeddingBagPaddingIdxOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
auto context = op->getContext();
Value weight = adaptor.weight();
Value indices = adaptor.indices();
Value offsets = adaptor.offsets();
Value scaleGradByFreq = adaptor.scale_grad_by_freq();
Value mode = op.mode();
Value sparse = op.sparse();
Value includeLastOffset = op.include_last_offset();
int64_t modeInt;
if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
return rewriter.notifyMatchFailure(
op, "mode is expected to be a constant integer value.");
}
if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) {
return rewriter.notifyMatchFailure(
op,
"Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag.");
}
bool isSparse;
if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) {
return rewriter.notifyMatchFailure(
op, "sparse is expected to be a constant boolean value.");
}
if (isSparse) {
return rewriter.notifyMatchFailure(
op,
"Unimplemented: Sparse mode is not supported yet for EmbeddingBag.");
}
bool discardLastOffset;
if (!matchPattern(includeLastOffset,
m_TorchConstantBool(&discardLastOffset))) {
return rewriter.notifyMatchFailure(
op,
"include_last_offset is expected to be a constant boolean value.");
}
auto weightTy = weight.getType().cast<RankedTensorType>();
if (weightTy.getRank() != 2)
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
auto indicesTy = indices.getType().cast<RankedTensorType>();
if (indicesTy.getRank() != 1)
return rewriter.notifyMatchFailure(op, "indices must be a vector");
auto offsetsTy = offsets.getType().cast<RankedTensorType>();
if (offsetsTy.getRank() != 1)
return rewriter.notifyMatchFailure(op, "offsets much be a vector");
Type weightElemTy = weightTy.getElementType();
int64_t iterationMapDimension = weightTy.getRank() + indicesTy.getRank();
SmallVector<AffineExpr> indicesExpr;
indicesExpr.push_back(mlir::getAffineDimExpr(1, context));
auto indicesIndexingMap =
AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
indicesExpr, context);
SmallVector<AffineExpr> offsetsExpr;
offsetsExpr.push_back(mlir::getAffineDimExpr(0, context));
auto offsetIndexingMap =
AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
offsetsExpr, context);
SmallVector<AffineExpr> outputExpr;
outputExpr.push_back(mlir::getAffineDimExpr(0, context));
outputExpr.push_back(mlir::getAffineDimExpr(2, context));
auto outputIndexingMap =
AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
outputExpr, context);
SmallVector<AffineMap, 3> indexingMaps = {
indicesIndexingMap,
offsetIndexingMap,
outputIndexingMap,
};
SmallVector<StringRef> iteratorTypes(iterationMapDimension,
getParallelIteratorTypeName());
Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
Value initTensor;
Value offsetsLength;
Value indicesLength;
if (!discardLastOffset) {
SmallVector<Value> sizes{getDimOp(rewriter, loc, offsets, 0),
embeddingDim};
initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy);
offsetsLength = getDimOp(rewriter, loc, offsets, 0);
indicesLength = getDimOp(rewriter, loc, indices, 0);
} else {
return rewriter.notifyMatchFailure(
op, "Unimplemented: include last offset is not yet "
"supported for EmbeddingBag.");
}
Value embeddingBagResult =
rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), ValueRange{indices, offsets},
initTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value indexInIndices = args[0];
Value offsetsI = args[1];
Value initTensorElem = args[2];
Value indexI = b.create<linalg::IndexOp>(loc, /*value=*/0);
Value indexIToInt = castIndexToInt64(b, loc, indexI);
Value one = getConstant(
b, loc, 1,
mlir::IntegerType::get(getContext(), 64,
IntegerType::Signless));
Value offsetIndexPlusOneInt =
b.create<arith::AddIOp>(loc, indexIToInt, one);
Value offsetIndexPlusOne =
castIntToIndex(b, loc, offsetIndexPlusOneInt);
Value checkLast = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq,
castIndexToInt64(b, loc, offsetsLength),
offsetIndexPlusOneInt);
Value nextOffset = b.create<arith::SelectOp>(
loc, checkLast, castIndexToInt64(b, loc, indicesLength),
b.create<tensor::ExtractOp>(loc, offsets,
offsetIndexPlusOne));
Value indicesIndex = castIndexToInt64(
b, loc, b.create<linalg::IndexOp>(loc, /*value=*/1));
Value offsetLessThanIndicesIndex = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, offsetsI, indicesIndex);
Value offsetEqualToIndicesIndex = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex);
Value offsetLessThanOrEqualToIndicesIndex =
b.create<arith::OrIOp>(loc, offsetLessThanIndicesIndex,
offsetEqualToIndicesIndex);
Value indicesIndexLessThanNextOffset =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
indicesIndex, nextOffset);
Value indicesIndexWithinBounds = b.create<arith::AndIOp>(
loc, offsetLessThanOrEqualToIndicesIndex,
indicesIndexLessThanNextOffset);
SmallVector<Value> indexIntoWeight;
indexIntoWeight.push_back(
castIntToIndex(b, loc, indexInIndices));
indexIntoWeight.push_back(
b.create<linalg::IndexOp>(loc, /*value=*/2));
Value weightElem = b.create<tensor::ExtractOp>(
loc, weight, indexIntoWeight);
Value addResult = b.create<arith::AddFOp>(loc, weightElem,
initTensorElem);
Value select =
b.create<arith::SelectOp>(loc, indicesIndexWithinBounds,
addResult, initTensorElem);
b.create<linalg::YieldOp>(loc, select);
})
.getResult(0);
// cast outputType.
auto restulType0 = typeConverter->convertType(op->getResult(0).getType());
Value castedEmbeddingBagResult =
rewriter.create<tensor::CastOp>(loc, restulType0, embeddingBagResult);
// offset2 tensor, this should be an empty tensor for the sum mode
SmallVector<Value> offsetResultSize;
Type offsetElemTy = offsetsTy.getElementType();
Value zeroDim = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/0);
offsetResultSize.push_back(zeroDim);
Value offsetResult = rewriter.create<linalg::InitTensorOp>(
loc, offsetResultSize, offsetElemTy);
auto resultType1 = typeConverter->convertType(op->getResult(1).getType());
Value castedOffsetResult =
rewriter.create<tensor::CastOp>(loc, resultType1, offsetResult);
SmallVector<Value> offsetSize = getTensorSizes(rewriter, loc, offsets);
// bagsize, vector of size offset with zeros, I think this is always just
// a vector of zeros in the sum mode
Value bagSize =
createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
auto resultType2 = typeConverter->convertType(op->getResult(2).getType());
Value castedBagSizeResult =
rewriter.create<tensor::CastOp>(loc, resultType2, bagSize);
// max indices, vector of size offset with zeros, this is also always a
// vector of zeros in the sum mode. Its mainly used in the max mode.
Value indicesOut =
createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
auto resultType3 = typeConverter->convertType(op->getResult(3).getType());
Value castedMaxIndices =
rewriter.create<tensor::CastOp>(loc, resultType3, indicesOut);
rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult,
castedBagSizeResult, castedMaxIndices});
return success();
}
};
} // namespace
namespace {
// Let's say we have an input tensor: initialized with some random values of
// size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
@ -468,4 +726,6 @@ void mlir::torch::torch_to_linalg::
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
target.addIllegalOp<AtenIndexTensorOp>();
patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
target.addIllegalOp<AtenEmbeddingBagPaddingIdxOp>();
patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
}

View File

@ -1036,6 +1036,28 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}
// case for Embedding bag padding idx.
if (auto embedding_bag_padding_idx =
dyn_cast<AtenEmbeddingBagPaddingIdxOp>(op)) {
auto resultFloatKnowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
resultFloatKnowledge.dtype = Float32Type::get(op->getContext());
incorporateKnowledge(embedding_bag_padding_idx.getResult(0),
resultFloatKnowledge);
auto resultIntKnowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
resultIntKnowledge.dtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
for (int64_t i = 1; i < 4; i++) {
incorporateKnowledge(embedding_bag_padding_idx.getResult(i),
resultIntKnowledge);
}
return;
}
if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
return;

View File

@ -6424,6 +6424,68 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.embedding_bag.padding_idx"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional<list<int>>, %arg7: !torch.bool, %arg8: !torch.optional<int>) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>> {
%none = torch.constant.none
%str = torch.constant.str "AssertionError: "
%int2 = torch.constant.int 2
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
%1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %1 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
%3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %3 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%4 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
%5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool
torch.prim.If %5 -> () {
torch.prim.If.yield
} else {
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
torch.prim.If.yield
}
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
%7 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int
%8 = torch.prim.If %arg7 -> (!torch.int) {
%19 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int
torch.prim.If.yield %19 : !torch.int
} else {
torch.prim.If.yield %7 : !torch.int
}
%9 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int
%10 = torch.aten.append.t %6, %8 : !torch.list<int>, !torch.int -> !torch.list<int>
%11 = torch.aten.append.t %6, %9 : !torch.list<int>, !torch.int -> !torch.list<int>
%12 = torch.prim.ListConstruct : () -> !torch.list<int>
%13 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool
%14 = torch.prim.If %13 -> (!torch.list<int>) {
%19 = torch.aten.append.t %12, %int0 : !torch.list<int>, !torch.int -> !torch.list<int>
torch.prim.If.yield %12 : !torch.list<int>
} else {
%19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list<int>) -> !torch.list<int>
torch.prim.If.yield %19 : !torch.list<int>
}
%15 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>
%16 = torch.aten.eq.int %arg4, %int2 : !torch.int, !torch.int -> !torch.bool
%17 = torch.prim.If %16 -> (!torch.list<int>) {
%19 = func.call @__torch__.torch.jit._shape_functions._copy(%6) : (!torch.list<int>) -> !torch.list<int>
torch.prim.If.yield %19 : !torch.list<int>
} else {
%19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>
torch.prim.If.yield %19 : !torch.list<int>
}
%18 = torch.prim.TupleConstruct %6, %14, %15, %17 : !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>
return %18 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>
}
func.func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {
%int-1 = torch.constant.int -1
%true = torch.constant.bool true

View File

@ -948,6 +948,34 @@ def atenindex_puthacked_twin(self: List[int], indices: List[List[int]], va
def atenembedding(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
def atenembedding_bagpadding_idx(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]:
assert len(weight) == 2
assert len(indices) == 1
assert len(offsets) == 1
output_bag_shape: List[int] = []
out_dim0 = offsets[0]
if (include_last_offset):
out_dim0 = out_dim0 - 1
out_dim1 = weight[1]
output_bag_shape.append(out_dim0)
output_bag_shape.append(out_dim1)
offset2bag_shape: List[int] = []
if mode == 1:
offset2bag_shape.append(0)
else:
offset2bag_shape = upstream_shape_functions._copy(indices)
bag_size_shape = upstream_shape_functions._copy(offsets)
max_indices_shape: List[int] = []
if mode == 2:
max_indices_shape = upstream_shape_functions._copy(output_bag_shape)
else:
max_indices_shape = upstream_shape_functions._copy(offsets)
return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape
@check_shape_function([
Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.

View File

@ -424,6 +424,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
emit("aten::detach : (Tensor) -> (Tensor)")
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)")
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")

View File

@ -2583,3 +2583,25 @@ class NumpyTRank0Module(torch.nn.Module):
@register_test_case(module_factory=lambda: NumpyTRank0Module())
def NumpyTRank0Module_basic(module, tu: TestUtils):
module.forward(torch.tensor(7, dtype=torch.float32))
class AtenEmbeddingBagSumExample(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
([-1], torch.int64, True),
([-1], torch.int64, True),
])
def forward(self, weight, indices, offsets):
return torch.ops.aten.embedding_bag(weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None)
@register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample())
def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils):
weight = torch.rand(100, 10)
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
module.forward(weight, indices, offsets)