mirror of https://github.com/llvm/torch-mlir
E2E support for AtenEmbeddingBagPaddingIdxOp SUM Mode (#1066)
parent
554570f3ab
commit
ed13ebfd8d
|
@ -5036,6 +5036,40 @@ def Torch_AtenEmbeddingOp : Torch_Op<"aten.embedding", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenEmbeddingBagPaddingIdxOp : Torch_Op<"aten.embedding_bag.padding_idx", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$weight,
|
||||
AnyTorchTensorType:$indices,
|
||||
AnyTorchTensorType:$offsets,
|
||||
Torch_BoolType:$scale_grad_by_freq,
|
||||
Torch_IntType:$mode,
|
||||
Torch_BoolType:$sparse,
|
||||
AnyTorchOptionalTensorType:$per_sample_weights,
|
||||
Torch_BoolType:$include_last_offset,
|
||||
AnyTorchOptionalIntType:$padding_idx
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result0,
|
||||
AnyTorchTensorType:$result1,
|
||||
AnyTorchTensorType:$result2,
|
||||
AnyTorchTensorType:$result3
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenEmbeddingBagPaddingIdxOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 9, 4);
|
||||
}
|
||||
void AtenEmbeddingBagPaddingIdxOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 9, 4);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -153,6 +153,13 @@ enum MemoryFormat {
|
|||
//===----------------------------------------------------------------------===//
|
||||
enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Possible value for `EmbeddingBag Mode` argument for Embedding bag ops.
|
||||
// Source:
|
||||
// https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h
|
||||
//===-----------------------------------------------------------------------===//
|
||||
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
|
||||
|
||||
} // namespace torch_upstream
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -168,6 +168,264 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// AtenEmbeddingPaddingIdxOp
|
||||
// SUM mode == integer 0
|
||||
// Sums bags of embeddings together from a weight tensor based on an index and
|
||||
// offset Vector. Example arguments weight = [[1, 3, 5, 3],
|
||||
// [3, 4, 2, 1],
|
||||
// [2, 2, 3, 2],
|
||||
// [0, 4, 2, 1]]
|
||||
//
|
||||
// indices = [0, 2, 3, 1, 2, 3, 2, 1, 0, 1]
|
||||
// offsets = [0, 3, 5]
|
||||
//
|
||||
// output_tensor = initZeroTensor(offsets_length, embedding_size)
|
||||
//
|
||||
// for i in range(offsets_length): <- dim0
|
||||
// for j in range(indices_length): <- dim1
|
||||
// for k in range(embedding_size): <- dim2
|
||||
// if(offsets[i] <= j and j < offsets[i+1]):
|
||||
// output_tensor[i][k] = output_tensor[i][k] +
|
||||
// weight[indices[j]][k]
|
||||
// else:
|
||||
// break
|
||||
//
|
||||
// Indexing maps for linalg::Generic ops
|
||||
//
|
||||
//
|
||||
// indices_indexing_map = (d0, d1, d2) -> (d1)
|
||||
// offset_indexing_map = (d0, d1, d2) -> (d0)
|
||||
// output_indexing_map = (d0, d1, d2) -> (d0, d2)
|
||||
//
|
||||
// TODO: Find an optimal lowering.
|
||||
// current lowering is not optimal for bags of large embeddings.
|
||||
// Since it traverses the output tensor multiple times.
|
||||
//
|
||||
//
|
||||
|
||||
class ConvertAtenEmbeddingBagPaddingIdxOp
|
||||
: public OpConversionPattern<AtenEmbeddingBagPaddingIdxOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenEmbeddingBagPaddingIdxOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
Location loc = op->getLoc();
|
||||
auto context = op->getContext();
|
||||
Value weight = adaptor.weight();
|
||||
Value indices = adaptor.indices();
|
||||
Value offsets = adaptor.offsets();
|
||||
Value scaleGradByFreq = adaptor.scale_grad_by_freq();
|
||||
Value mode = op.mode();
|
||||
Value sparse = op.sparse();
|
||||
Value includeLastOffset = op.include_last_offset();
|
||||
|
||||
int64_t modeInt;
|
||||
if (!matchPattern(mode, m_TorchConstantInt(&modeInt))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "mode is expected to be a constant integer value.");
|
||||
}
|
||||
|
||||
if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag.");
|
||||
}
|
||||
|
||||
bool isSparse;
|
||||
if (!matchPattern(sparse, m_TorchConstantBool(&isSparse))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "sparse is expected to be a constant boolean value.");
|
||||
}
|
||||
|
||||
if (isSparse) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Unimplemented: Sparse mode is not supported yet for EmbeddingBag.");
|
||||
}
|
||||
|
||||
bool discardLastOffset;
|
||||
if (!matchPattern(includeLastOffset,
|
||||
m_TorchConstantBool(&discardLastOffset))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"include_last_offset is expected to be a constant boolean value.");
|
||||
}
|
||||
|
||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
||||
if (weightTy.getRank() != 2)
|
||||
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
|
||||
|
||||
auto indicesTy = indices.getType().cast<RankedTensorType>();
|
||||
if (indicesTy.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "indices must be a vector");
|
||||
|
||||
auto offsetsTy = offsets.getType().cast<RankedTensorType>();
|
||||
if (offsetsTy.getRank() != 1)
|
||||
return rewriter.notifyMatchFailure(op, "offsets much be a vector");
|
||||
|
||||
Type weightElemTy = weightTy.getElementType();
|
||||
|
||||
int64_t iterationMapDimension = weightTy.getRank() + indicesTy.getRank();
|
||||
SmallVector<AffineExpr> indicesExpr;
|
||||
indicesExpr.push_back(mlir::getAffineDimExpr(1, context));
|
||||
auto indicesIndexingMap =
|
||||
AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
|
||||
indicesExpr, context);
|
||||
|
||||
SmallVector<AffineExpr> offsetsExpr;
|
||||
offsetsExpr.push_back(mlir::getAffineDimExpr(0, context));
|
||||
|
||||
auto offsetIndexingMap =
|
||||
AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
|
||||
offsetsExpr, context);
|
||||
|
||||
SmallVector<AffineExpr> outputExpr;
|
||||
outputExpr.push_back(mlir::getAffineDimExpr(0, context));
|
||||
outputExpr.push_back(mlir::getAffineDimExpr(2, context));
|
||||
|
||||
auto outputIndexingMap =
|
||||
AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0,
|
||||
outputExpr, context);
|
||||
|
||||
SmallVector<AffineMap, 3> indexingMaps = {
|
||||
indicesIndexingMap,
|
||||
offsetIndexingMap,
|
||||
outputIndexingMap,
|
||||
};
|
||||
|
||||
SmallVector<StringRef> iteratorTypes(iterationMapDimension,
|
||||
getParallelIteratorTypeName());
|
||||
|
||||
Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
|
||||
Value initTensor;
|
||||
Value offsetsLength;
|
||||
Value indicesLength;
|
||||
if (!discardLastOffset) {
|
||||
SmallVector<Value> sizes{getDimOp(rewriter, loc, offsets, 0),
|
||||
embeddingDim};
|
||||
|
||||
initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy);
|
||||
offsetsLength = getDimOp(rewriter, loc, offsets, 0);
|
||||
indicesLength = getDimOp(rewriter, loc, indices, 0);
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unimplemented: include last offset is not yet "
|
||||
"supported for EmbeddingBag.");
|
||||
}
|
||||
|
||||
Value embeddingBagResult =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
loc, initTensor.getType(), ValueRange{indices, offsets},
|
||||
initTensor,
|
||||
/*indexingMaps=*/indexingMaps,
|
||||
/*iteratorTypes=*/iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value indexInIndices = args[0];
|
||||
Value offsetsI = args[1];
|
||||
Value initTensorElem = args[2];
|
||||
|
||||
Value indexI = b.create<linalg::IndexOp>(loc, /*value=*/0);
|
||||
Value indexIToInt = castIndexToInt64(b, loc, indexI);
|
||||
Value one = getConstant(
|
||||
b, loc, 1,
|
||||
mlir::IntegerType::get(getContext(), 64,
|
||||
IntegerType::Signless));
|
||||
Value offsetIndexPlusOneInt =
|
||||
b.create<arith::AddIOp>(loc, indexIToInt, one);
|
||||
|
||||
Value offsetIndexPlusOne =
|
||||
castIntToIndex(b, loc, offsetIndexPlusOneInt);
|
||||
Value checkLast = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq,
|
||||
castIndexToInt64(b, loc, offsetsLength),
|
||||
offsetIndexPlusOneInt);
|
||||
Value nextOffset = b.create<arith::SelectOp>(
|
||||
loc, checkLast, castIndexToInt64(b, loc, indicesLength),
|
||||
b.create<tensor::ExtractOp>(loc, offsets,
|
||||
offsetIndexPlusOne));
|
||||
|
||||
Value indicesIndex = castIndexToInt64(
|
||||
b, loc, b.create<linalg::IndexOp>(loc, /*value=*/1));
|
||||
|
||||
Value offsetLessThanIndicesIndex = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, offsetsI, indicesIndex);
|
||||
Value offsetEqualToIndicesIndex = b.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex);
|
||||
Value offsetLessThanOrEqualToIndicesIndex =
|
||||
b.create<arith::OrIOp>(loc, offsetLessThanIndicesIndex,
|
||||
offsetEqualToIndicesIndex);
|
||||
|
||||
Value indicesIndexLessThanNextOffset =
|
||||
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
indicesIndex, nextOffset);
|
||||
|
||||
Value indicesIndexWithinBounds = b.create<arith::AndIOp>(
|
||||
loc, offsetLessThanOrEqualToIndicesIndex,
|
||||
indicesIndexLessThanNextOffset);
|
||||
|
||||
SmallVector<Value> indexIntoWeight;
|
||||
indexIntoWeight.push_back(
|
||||
castIntToIndex(b, loc, indexInIndices));
|
||||
indexIntoWeight.push_back(
|
||||
b.create<linalg::IndexOp>(loc, /*value=*/2));
|
||||
Value weightElem = b.create<tensor::ExtractOp>(
|
||||
loc, weight, indexIntoWeight);
|
||||
|
||||
Value addResult = b.create<arith::AddFOp>(loc, weightElem,
|
||||
initTensorElem);
|
||||
Value select =
|
||||
b.create<arith::SelectOp>(loc, indicesIndexWithinBounds,
|
||||
addResult, initTensorElem);
|
||||
b.create<linalg::YieldOp>(loc, select);
|
||||
})
|
||||
.getResult(0);
|
||||
|
||||
// cast outputType.
|
||||
auto restulType0 = typeConverter->convertType(op->getResult(0).getType());
|
||||
Value castedEmbeddingBagResult =
|
||||
rewriter.create<tensor::CastOp>(loc, restulType0, embeddingBagResult);
|
||||
|
||||
// offset2 tensor, this should be an empty tensor for the sum mode
|
||||
SmallVector<Value> offsetResultSize;
|
||||
Type offsetElemTy = offsetsTy.getElementType();
|
||||
Value zeroDim = rewriter.create<arith::ConstantIndexOp>(loc, /*value=*/0);
|
||||
offsetResultSize.push_back(zeroDim);
|
||||
Value offsetResult = rewriter.create<linalg::InitTensorOp>(
|
||||
loc, offsetResultSize, offsetElemTy);
|
||||
auto resultType1 = typeConverter->convertType(op->getResult(1).getType());
|
||||
Value castedOffsetResult =
|
||||
rewriter.create<tensor::CastOp>(loc, resultType1, offsetResult);
|
||||
|
||||
SmallVector<Value> offsetSize = getTensorSizes(rewriter, loc, offsets);
|
||||
// bagsize, vector of size offset with zeros, I think this is always just
|
||||
// a vector of zeros in the sum mode
|
||||
Value bagSize =
|
||||
createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
|
||||
auto resultType2 = typeConverter->convertType(op->getResult(2).getType());
|
||||
Value castedBagSizeResult =
|
||||
rewriter.create<tensor::CastOp>(loc, resultType2, bagSize);
|
||||
|
||||
// max indices, vector of size offset with zeros, this is also always a
|
||||
// vector of zeros in the sum mode. Its mainly used in the max mode.
|
||||
Value indicesOut =
|
||||
createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy);
|
||||
auto resultType3 = typeConverter->convertType(op->getResult(3).getType());
|
||||
Value castedMaxIndices =
|
||||
rewriter.create<tensor::CastOp>(loc, resultType3, indicesOut);
|
||||
|
||||
rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult,
|
||||
castedBagSizeResult, castedMaxIndices});
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// Let's say we have an input tensor: initialized with some random values of
|
||||
// size [4, 5, 6]. An index tensor (always 1-d): [0, 2] of size [2], and an
|
||||
|
@ -468,4 +726,6 @@ void mlir::torch::torch_to_linalg::
|
|||
patterns.add<ConvertAtenIndexSelectOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenIndexTensorOp>();
|
||||
patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenEmbeddingBagPaddingIdxOp>();
|
||||
patterns.add<ConvertAtenEmbeddingBagPaddingIdxOp>(typeConverter, context);
|
||||
}
|
||||
|
|
|
@ -1036,6 +1036,28 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
return;
|
||||
}
|
||||
|
||||
// case for Embedding bag padding idx.
|
||||
if (auto embedding_bag_padding_idx =
|
||||
dyn_cast<AtenEmbeddingBagPaddingIdxOp>(op)) {
|
||||
|
||||
auto resultFloatKnowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
resultFloatKnowledge.dtype = Float32Type::get(op->getContext());
|
||||
|
||||
incorporateKnowledge(embedding_bag_padding_idx.getResult(0),
|
||||
resultFloatKnowledge);
|
||||
auto resultIntKnowledge =
|
||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||
resultIntKnowledge.dtype =
|
||||
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||
|
||||
for (int64_t i = 1; i < 4; i++) {
|
||||
incorporateKnowledge(embedding_bag_padding_idx.getResult(i),
|
||||
resultIntKnowledge);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto softmaxIntOp = dyn_cast<AtenSoftmaxIntOp>(op)) {
|
||||
visitAtenSoftmaxLikeOp(softmaxIntOp, operands);
|
||||
return;
|
||||
|
|
|
@ -6424,6 +6424,68 @@ module {
|
|||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.embedding_bag.padding_idx"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional<list<int>>, %arg7: !torch.bool, %arg8: !torch.optional<int>) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>> {
|
||||
%none = torch.constant.none
|
||||
%str = torch.constant.str "AssertionError: "
|
||||
%int2 = torch.constant.int 2
|
||||
%int1 = torch.constant.int 1
|
||||
%int0 = torch.constant.int 0
|
||||
%0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||
%1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If %1 -> () {
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int
|
||||
%3 = torch.aten.eq.int %2, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If %3 -> () {
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%4 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int
|
||||
%5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||
torch.prim.If %5 -> () {
|
||||
torch.prim.If.yield
|
||||
} else {
|
||||
torch.prim.RaiseException %str, %none : !torch.str, !torch.none
|
||||
torch.prim.If.yield
|
||||
}
|
||||
%6 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
%7 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%8 = torch.prim.If %arg7 -> (!torch.int) {
|
||||
%19 = torch.aten.sub.int %7, %int1 : !torch.int, !torch.int -> !torch.int
|
||||
torch.prim.If.yield %19 : !torch.int
|
||||
} else {
|
||||
torch.prim.If.yield %7 : !torch.int
|
||||
}
|
||||
%9 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int
|
||||
%10 = torch.aten.append.t %6, %8 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
%11 = torch.aten.append.t %6, %9 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
%12 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||
%13 = torch.aten.eq.int %arg4, %int1 : !torch.int, !torch.int -> !torch.bool
|
||||
%14 = torch.prim.If %13 -> (!torch.list<int>) {
|
||||
%19 = torch.aten.append.t %12, %int0 : !torch.list<int>, !torch.int -> !torch.list<int>
|
||||
torch.prim.If.yield %12 : !torch.list<int>
|
||||
} else {
|
||||
%19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg1) : (!torch.list<int>) -> !torch.list<int>
|
||||
torch.prim.If.yield %19 : !torch.list<int>
|
||||
}
|
||||
%15 = call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>
|
||||
%16 = torch.aten.eq.int %arg4, %int2 : !torch.int, !torch.int -> !torch.bool
|
||||
%17 = torch.prim.If %16 -> (!torch.list<int>) {
|
||||
%19 = func.call @__torch__.torch.jit._shape_functions._copy(%6) : (!torch.list<int>) -> !torch.list<int>
|
||||
torch.prim.If.yield %19 : !torch.list<int>
|
||||
} else {
|
||||
%19 = func.call @__torch__.torch.jit._shape_functions._copy(%arg2) : (!torch.list<int>) -> !torch.list<int>
|
||||
torch.prim.If.yield %19 : !torch.list<int>
|
||||
}
|
||||
%18 = torch.prim.TupleConstruct %6, %14, %15, %17 : !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>
|
||||
return %18 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.nll_loss_forward"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {
|
||||
%int-1 = torch.constant.int -1
|
||||
%true = torch.constant.bool true
|
||||
|
|
|
@ -948,6 +948,34 @@ def aten〇index_put〇hacked_twin(self: List[int], indices: List[List[int]], va
|
|||
def aten〇embedding(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
|
||||
|
||||
def aten〇embedding_bag〇padding_idx(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]:
|
||||
assert len(weight) == 2
|
||||
assert len(indices) == 1
|
||||
assert len(offsets) == 1
|
||||
output_bag_shape: List[int] = []
|
||||
out_dim0 = offsets[0]
|
||||
if (include_last_offset):
|
||||
out_dim0 = out_dim0 - 1
|
||||
out_dim1 = weight[1]
|
||||
output_bag_shape.append(out_dim0)
|
||||
output_bag_shape.append(out_dim1)
|
||||
|
||||
offset2bag_shape: List[int] = []
|
||||
if mode == 1:
|
||||
offset2bag_shape.append(0)
|
||||
else:
|
||||
offset2bag_shape = upstream_shape_functions._copy(indices)
|
||||
|
||||
bag_size_shape = upstream_shape_functions._copy(offsets)
|
||||
|
||||
max_indices_shape: List[int] = []
|
||||
if mode == 2:
|
||||
max_indices_shape = upstream_shape_functions._copy(output_bag_shape)
|
||||
else:
|
||||
max_indices_shape = upstream_shape_functions._copy(offsets)
|
||||
|
||||
return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape
|
||||
|
||||
@check_shape_function([
|
||||
Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case.
|
||||
Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim.
|
||||
|
|
|
@ -424,6 +424,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)")
|
||||
emit("aten::detach : (Tensor) -> (Tensor)")
|
||||
emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
|
||||
emit("aten::embedding_bag.padding_idx : (Tensor, Tensor, Tensor, bool, int, bool, Tensor?, bool, int?) -> (Tensor, Tensor, Tensor, Tensor)")
|
||||
emit("aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
emit("aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||
|
|
|
@ -2583,3 +2583,25 @@ class NumpyTRank0Module(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: NumpyTRank0Module())
|
||||
def NumpyTRank0Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor(7, dtype=torch.float32))
|
||||
|
||||
class AtenEmbeddingBagSumExample(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
([-1], torch.int64, True),
|
||||
([-1], torch.int64, True),
|
||||
])
|
||||
def forward(self, weight, indices, offsets):
|
||||
return torch.ops.aten.embedding_bag(weight, indices, offsets, scale_grad_by_freq=False, mode=0, sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None)
|
||||
|
||||
@register_test_case(module_factory=lambda: AtenEmbeddingBagSumExample())
|
||||
def AtenEmbeddingBagSumExample_basic(module, tu: TestUtils):
|
||||
weight = torch.rand(100, 10)
|
||||
indices = torch.LongTensor([0, 1, 2, 2, 0, 2, 1, 3, 20, 50, 99, 2, 4, 5, 6, 7, 34, 54])
|
||||
offsets = torch.LongTensor([0, 3, 5, 7, 9, 10, 15])
|
||||
module.forward(weight, indices, offsets)
|
||||
|
|
Loading…
Reference in New Issue