mirror of https://github.com/llvm/torch-mlir
add cumprod
parent b3942ff984
commit 4945e3a7d0
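For context on the op this commit lowers: aten.cumprod computes the inclusive cumulative product of a tensor along a single dimension. A quick PyTorch illustration (not part of the patch):

import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
# Inclusive cumulative product along dim=1:
# row 0 -> [1, 1*2, 1*2*3] = [1,  2,   6]
# row 1 -> [4, 4*5, 4*5*6] = [4, 20, 120]
print(torch.cumprod(x, dim=1))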
@@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
 Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
                            Type elemTy);
 
+Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
+                          Type elemTy);
+
 Value castIntToIndex(OpBuilder &b, Location loc, Value v);
@@ -1497,6 +1497,79 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertAtenCumprodOp : public OpConversionPattern<AtenCumprodOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Value input = adaptor.getSelf();
+    auto resultType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
+    Type elementType = resultType.getElementType();
+    Type inputElementType =
+        cast<RankedTensorType>(input.getType()).getElementType();
+
+    // Convert the input element type to the result's element type. The only
+    // possible mismatch is when the input element type is an integer other
+    // than `si64`, so we convert the input directly to `si64`. All other
+    // cases are handled by the dtype definition for this op.
+    if (elementType != inputElementType) {
+      Value torchInput = convertTensorToDtype(
+          rewriter, loc, op.getSelf(),
+          rewriter.getIntegerType(64, IntegerType::Signed));
+      input = typeConverter->materializeTargetConversion(
+          rewriter, loc, typeConverter->convertType(torchInput.getType()),
+          torchInput);
+    }
+
+    int64_t inputRank = resultType.getRank();
+    Value dtype = op.getDtype();
+    if (!isa<Torch::NoneType>(dtype.getType()))
+      return rewriter.notifyMatchFailure(
+          op, "unsupported: dtype argument not supported");
+
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: only constant dim value is supported");
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank))
+      return rewriter.notifyMatchFailure(op, "invalid dim");
+
+    SmallVector<Value> sizes = getTensorSizes(rewriter, loc, input);
+    Value output = createOneInitTensor(rewriter, loc, sizes, elementType);
+    output = rewriter.create<tensor::CastOp>(loc, resultType, output);
+
+    SmallVector<Value> accSizes(sizes);
+    accSizes.erase(accSizes.begin() + dim);
+    SmallVector<int64_t> accStatic(
+        makeShapeTorchCompatible(resultType.getShape()));
+    accStatic.erase(accStatic.begin() + dim);
+    Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType);
+    Type accType =
+        RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType);
+    acc = rewriter.create<tensor::CastOp>(loc, accType, acc);
+
+    Value result = createTMTensorScanOp(
+        rewriter, loc, input, output, acc, dim, /*inclusive=*/true,
+        [](OpBuilder &b, Location loc, Value input, Value acc) {
+          Value prod =
+              (isa<mlir::FloatType>(input.getType())
+                   ? b.create<arith::MulFOp>(loc, input, acc)->getResult(0)
+                   : b.create<arith::MulIOp>(loc, input, acc)->getResult(0));
+          b.create<TMTensor::YieldOp>(loc, prod);
+        });
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
 public:
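The lowering above seeds both the output and the accumulator with ones (the identity element of multiplication, via the new createOneInitTensor helper) and delegates the traversal to createTMTensorScanOp with inclusive=true, whose body multiplies each incoming element into the running accumulator (arith.mulf for floats, arith.muli for integers). A minimal NumPy sketch of that inclusive scan, for reference only (the function name is hypothetical, not the TMTensor implementation):

import numpy as np

def inclusive_scan_mul(x, dim):
    # Move the scanned dimension to the front so we can iterate over it.
    x = np.moveaxis(x, dim, 0)
    out = np.empty_like(x)
    # The accumulator starts at the multiplicative identity, mirroring
    # the one-filled init tensors created in the patch.
    acc = np.ones(x.shape[1:], dtype=x.dtype)
    for i in range(x.shape[0]):
        acc = acc * x[i]  # scan body: MulFOp / MulIOp in the lowering
        out[i] = acc      # inclusive: element i contributes to result i
    return np.moveaxis(out, 0, dim)

# inclusive_scan_mul(a, 1) matches np.cumprod(a, axis=1).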
@@ -2185,6 +2258,8 @@ public:
     patterns.add<ConvertAtenSortOp>(typeConverter, context);
     target.addIllegalOp<AtenCumsumOp>();
     patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
+    target.addIllegalOp<AtenCumprodOp>();
+    patterns.add<ConvertAtenCumprodOp>(typeConverter, context);
     target.addIllegalOp<AtenScaledDotProductAttentionOp>();
     patterns.add<ConvertAtenScaledDotProductAttentionOp>(typeConverter,
                                                           context);
@@ -138,6 +138,16 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
   return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
 }
 
+Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
+                          Type elemTy) {
+  Value initTensor =
+      b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
+  RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
+  Value c1 =
+      b.create<arith::ConstantOp>(loc, b.getOneAttr(type.getElementType()));
+  return b.create<linalg::FillOp>(loc, c1, initTensor).getResult(0);
+}
+
 Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
   assert(isa<IntegerType>(v.getType()) && "must be called with integer type");
   return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
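createOneInitTensor mirrors the existing createZeroInitTensor helper but fills the freshly allocated tensor with the element type's one value, which the cumprod lowering needs because 1 is the identity of the product scan. A rough Python analogue (illustrative only, the name is hypothetical):

import numpy as np

def create_one_init_tensor(sizes, dtype=np.float32):
    # Analogue of tensor.empty followed by linalg.fill with the constant one.
    return np.full(sizes, 1, dtype=dtype)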
@@ -9072,6 +9072,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "  func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
+"    return %arg0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
 "    return %arg0 : !torch.list<int>\n"
 "  }\n"
@@ -11754,6 +11757,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.optional<int>) -> !torch.int {\n"
+"    %int4 = torch.constant.int 4\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    %1 = torch.prim.If %0 -> (!torch.int) {\n"
+"      %2 = torch.prim.unchecked_cast %arg2 : !torch.optional<int> -> !torch.int\n"
+"      torch.prim.If.yield %2 : !torch.int\n"
+"    } else {\n"
+"      %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"      %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n"
+"      %4 = torch.prim.If %3 -> (!torch.int) {\n"
+"        torch.prim.If.yield %int4 : !torch.int\n"
+"      } else {\n"
+"        torch.prim.If.yield %2#1 : !torch.int\n"
+"      }\n"
+"      torch.prim.If.yield %4 : !torch.int\n"
+"    }\n"
+"    return %1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
@@ -1412,6 +1412,9 @@ def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: b
 def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
     return self
 
+def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]:
+    return self
+
 def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
     return self
@@ -2888,6 +2891,18 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt
         return torch.int64
     return self_dtype
 
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) +
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32))
+def aten〇cumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int:
+    if dtype is not None:
+        return dtype
+    self_rank, self_dtype = self_rank_dtype
+    if is_integer_dtype(self_dtype):
+        return torch.int64
+    return self_dtype
+
 
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
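The dtype rule added above follows PyTorch's own behaviour for cumprod: an explicit dtype argument wins, integer inputs otherwise promote to int64, and floating-point inputs keep their dtype. A small check (illustrative, not part of the test suite):

import torch

t_int = torch.arange(1, 5, dtype=torch.int32)
t_float = torch.rand(4)

print(torch.cumprod(t_int, 0).dtype)                       # torch.int64
print(torch.cumprod(t_int, 0, dtype=torch.float32).dtype)  # torch.float32
print(torch.cumprod(t_float, 0).dtype)                     # torch.float32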
@@ -4683,6 +4683,90 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class CumprodModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, val):
+        ones = torch.ones([1], dtype=torch.int32)
+        return torch.ops.aten.cumprod(val, ones.item())
+
+
+@register_test_case(module_factory=lambda: CumprodModule())
+def CumprodModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 7, 4))
+
+
+class CumprodStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 7, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, val):
+        return torch.ops.aten.cumprod(val, 1)
+
+
+@register_test_case(module_factory=lambda: CumprodStaticModule())
+def CumprodStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 7, 4))
+
+
+class CumprodStaticNegativeDimModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 7, 4], torch.float32, True),
+        ]
+    )
+    def forward(self, val):
+        return torch.ops.aten.cumprod(val, dim=-1)
+
+
+@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule())
+def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 7, 4))
+
+
+class CumprodInputDtypeInt32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 7, 4], torch.int32, True),
+        ]
+    )
+    def forward(self, val):
+        return torch.ops.aten.cumprod(val, 1)
+
+
+@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module())
+def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 7, 4).to(torch.int32))
+
+
+# ==============================================================================
+
+
 class AtenToDeviceModule(torch.nn.Module):
     def __init__(self):
         super().__init__()