[MLIR][TORCH] Add E2E support for aten.randint.low op

Signed-Off By: Vivek Khandelwal<vivek@nod-labs.com>
pull/1589/head snapshot-20221116.659
Vivek Khandelwal 2022-11-06 18:14:05 +05:30
parent 22a5067242
commit a1d3afdba9
7 changed files with 140 additions and 1 deletions

View File

@ -3334,6 +3334,35 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [
}];
}
def Torch_AtenRandintLowOp : Torch_Op<"aten.randint.low", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)`";
let arguments = (ins
Torch_IntType:$low,
Torch_IntType:$high,
AnyTorchListOfTorchIntType:$size,
AnyTorchOptionalIntType:$dtype,
AnyTorchOptionalIntType:$layout,
AnyTorchOptionalDeviceType:$device,
AnyTorchOptionalBoolType:$pin_memory
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRandintLowOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void AtenRandintLowOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
def Torch_AtenBernoulliTensorOp : Torch_Op<"aten.bernoulli.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -241,7 +241,7 @@ public:
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value startPlusLength =
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, length);
rewriter.replaceOpWithNewOp<AtenSliceTensorOp>(
op, op.getResult().getType(), op.self(), /*dim=*/dim, /*start=*/start,
/*end=*/startPlusLength, /*step=*/one);
@ -3063,6 +3063,56 @@ public:
};
} // namespace
namespace {
class DecomposeAtenRandintLowOp : public OpRewritePattern<AtenRandintLowOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRandintLowOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Type resultType = op.getType();
BaseTensorType resultTensorType = resultType.cast<BaseTensorType>();
int64_t cstLow, cstHigh;
if (!matchPattern(op.low(), m_TorchConstantInt(&cstLow)))
return rewriter.notifyMatchFailure(
op, "unimplemented: low must be a constant integer");
if (!matchPattern(op.high(), m_TorchConstantInt(&cstHigh)))
return rewriter.notifyMatchFailure(
op, "unimplemented: high must be a constant integer");
Value none = rewriter.create<ConstantNoneOp>(loc);
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
Value low = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr((double)cstLow));
Value high = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr((double)cstHigh));
BaseTensorType floatResultType =
resultTensorType
.getWithSizesAndDtype(resultTensorType.getSizes(),
rewriter.getF32Type())
.cast<BaseTensorType>();
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, floatResultType, op.size(), /*dtype=*/none, /*layout=*/op.layout(),
/*device=*/op.device(), /*pin_memory=*/op.pin_memory(),
/*memory_format=*/none);
Value result =
rewriter.create<AtenUniformOp>(loc, floatResultType, emptyTensor,
/*from=*/low,
/*to=*/high,
/*generator=*/none);
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(
op, resultType, result,
getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()),
/*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none);
return success();
}
};
} // namespace
namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@ -3264,6 +3314,8 @@ public:
target.addIllegalOp<AtenIndexTensorHackedTwinOp>();
patterns.add<DecomposeAtenMseLossOp>(context);
target.addIllegalOp<AtenMseLossOp>();
patterns.add<DecomposeAtenRandintLowOp>(context);
target.addIllegalOp<AtenRandintLowOp>();
for (std::string opName : legalOps) {
target.addLegalOp(OperationName(opName, context));

View File

@ -1158,6 +1158,17 @@ void TypeAnalysis::visitOperation(Operation *op,
return;
}
if (auto randIntLow = dyn_cast<AtenRandintLowOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type defaultDtype =
IntegerType::get(op->getContext(), 64, IntegerType::Signed);
knowledge.dtype =
getDtypeOrDefault(op->getContext(), randIntLow.dtype(), defaultDtype);
incorporateKnowledge(randIntLow.getResult(), knowledge);
return;
}
// Otherwise, this is an unknown operation, so reset the state.
setAllToEntryStates(results);
return;

View File

@ -6327,6 +6327,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.randint.low\"(%arg0: !torch.int, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" %0 = torch.derefine %arg0 : !torch.float to !torch.union<float, int>\n"
" %1 = torch.derefine %arg1 : !torch.float to !torch.union<float, int>\n"

View File

@ -828,6 +828,9 @@ def atencumsum(self: List[int], dim: int, dtype: Optional[int] = None) -> Lis
def atenrand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
return self
def atenrandintlow(low: int, high: int, size: List[int], dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return size
def atenarangestart_step(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory)

View File

@ -324,6 +324,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)")
emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)")
emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)")
emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)")
emit_with_mutating_variants("aten::bernoulli.Tensor : (Tensor, Tensor, Generator?) -> (Tensor)")
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")

View File

@ -276,3 +276,43 @@ class RandLikeDtypeModule(torch.nn.Module):
@register_test_case(module_factory=lambda: RandLikeDtypeModule())
def RandLikeDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1024, 1024).double())
# ==============================================================================
class RandIntLowModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
a = torch.ops.aten.randint(low=1, high=1000, size=[1024, 1024])
mean = torch.mean(a.to(torch.float32))
return mean
@register_test_case(module_factory=lambda: RandIntLowModule())
def RandIntLowModule_basic(module, tu: TestUtils):
module.forward()
# ==============================================================================
class RandIntLowDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
])
def forward(self):
a = torch.ops.aten.randint(low=1, high=1000, size=[128, 256, 512], dtype=torch.float64)
mean = torch.mean(a)
return mean
@register_test_case(module_factory=lambda: RandIntLowDtypeModule())
def RandIntLowDtypeModule_basic(module, tu: TestUtils):
module.forward()