[Torch][Linalg] Add basic support for RNG

This PR includes the following pieces:
- Add the torch `Generator` type. `Generator` is converted to `i64` by
the refbackend type converter.
- Add seed management support for the default global generator. The
`torch_c.get_next_seed` op is used to get the seed. On refbackend,
`torch_c.get_next_seed` is lowered to a load/update/store sequence on
the `memref<i64>` global variable `global_seed` in the
`InsertRngGlobals` pass.
- Add `aten.uniform_` (with an e2e test) as an example RNG op, and add
the `torch.pseudo.aten.uniform` op, which has the same operands and
result as `aten.uniform_` from the op registry except that it has
value semantics. A sketch of the overall scheme follows.
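For intuition, here is a minimal NumPy sketch (an illustration, not code
from this PR; the helper names are made up) of the scheme the lowering
implements: a global seed advanced by an LCG on every `get_next_seed`,
and a per-element LCG over the tensor indices whose result is scaled
into [from, to). The constants match the ones used in the lowering.

import numpy as np

MULTIPLIER = 6364136223846793005
INCREMENT = 1442695040888963407

global_seed = 0  # mirrors the zero-initialized @global_seed memref<i64>

def get_next_seed():
    # LCG step, truncated to the low 7 bits as in InsertRngGlobals.
    global global_seed
    global_seed = (global_seed * MULTIPLIER + INCREMENT) & 127
    return global_seed

def uniform(shape, lo, hi):
    out = np.empty(shape, dtype=np.float64)
    temp0 = get_next_seed()
    for idx in np.ndindex(*shape):
        temp = temp0
        # One LCG step per dimension, as in the linalg.generic body.
        for i in idx:
            temp = ((i + temp) * MULTIPLIER + INCREMENT) % 2**64
        # 5.4210108e-20 ~= 1 / (2^64 - 1), mapping temp into [lo, hi).
        out[idx] = temp * (hi - lo) * 5.4210108e-20 + lo
    return out

print(uniform((2, 3), 1.0, 10.0))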
pull/557/head snapshot-20220201.241
Yi Zhang 2022-01-11 20:59:42 -05:00
parent 0f083e770a
commit 0cb216a1ad
24 changed files with 521 additions and 14 deletions

@ -50,6 +50,7 @@ from . import constant_alloc
from . import threshold
from . import histogram_binning_calibration
from . import table_batch_embedding
from . import rng
def _get_argparse():
    config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']

@ -0,0 +1,44 @@
import torch
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class UniformModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
        ([-1, -1, -1], torch.float64, True),
    ])
    def forward(self, x, y, z):
        a = torch.ops.aten.uniform_(x, 1.0, 10.0)
        b = torch.ops.aten.uniform_(y, -20.0, -5.0)
        c = torch.ops.aten.uniform_(z, -15.0, 3.0)
        std = torch.cat([
            torch.flatten(torch.std(a)),
            torch.flatten(torch.std(b)),
            torch.flatten(torch.std(c))
        ])
        mean = torch.cat([
            torch.flatten(torch.mean(a)),
            torch.flatten(torch.mean(b)),
            torch.flatten(torch.mean(c))
        ])
        return std, mean

@register_test_case(module_factory=lambda: UniformModule())
def UniformModule_basic(module, tu: TestUtils):
    module.forward(
        tu.rand(256, 512, 64).double(),
        tu.rand(512, 1024, 128).double(),
        tu.rand(512, 256, 1024).double())
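The refbackend LCG cannot match PyTorch's native RNG bit-for-bit, so the
test compares aggregate statistics rather than element values: for large
tensors both backends should produce samples whose mean approaches
(a + b) / 2 and whose std approaches (b - a) / sqrt(12) for U(a, b). A
quick sanity check of those targets (an illustrative sketch, not part of
this change):

import torch

# For U(a, b): mean = (a + b) / 2, std = (b - a) / sqrt(12).
for a, b in [(1.0, 10.0), (-20.0, -5.0), (-15.0, 3.0)]:
    s = torch.empty(512, 1024, dtype=torch.float64).uniform_(a, b)
    print(f"mean {s.mean().item():.3f} vs {(a + b) / 2:.3f}, "
          f"std {s.std().item():.3f} vs {(b - a) / 12**0.5:.3f}")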

@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
/// Gets the !torch.Device type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.Generator type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.Generator type.
MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t);
/// Gets the !torch.Generator type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.bool type.
//===----------------------------------------------------------------------===//

@ -1325,6 +1325,22 @@ def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
let assemblyFormat = "$grad_output `,` $self `,` $threshold attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($threshold)) `->` qualified(type($result))";
}
def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$from,
Torch_FloatType:$to,
TorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($from)) `,` qualified(type($to)) `,` qualified(type($generator)) `->` qualified(type($result))";
}
def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
AllowsTypeRefinement,
HasValueSemantics

@ -912,4 +912,23 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
}];
}
// The value-semantic variant (without the trailing underscore) of
// `torch.aten.uniform_` doesn't exist in the PyTorch op registry, so add
// it here.
def Torch_PseudoAtenUniformOp : Torch_Op<"pseudo.aten.uniform", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "`uniform op : (Tensor, float, float, Generator?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$from,
Torch_FloatType:$to,
TorchOptionalGeneratorType:$generator
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
}
#endif // TORCH_OPS

@ -243,6 +243,10 @@ def Torch_DeviceType : Torch_Type<"Device", "Device"> {
let summary = "Torch device";
}
def Torch_GeneratorType : Torch_Type<"Generator", "Generator"> {
let summary = "Torch Generator for producing pseudo-random numbers";
}
def Torch_BoolType : Torch_Type<"Bool", "bool"> {
let summary = "Torch BoolType";
let description = [{
@ -384,6 +388,8 @@ def TorchOptionalStringType:
OptionalOf<Torch_StringType, "Optional torch Str type">;
def TorchOptionalDeviceType:
OptionalOf<Torch_DeviceType, "Optional torch device type">;
def TorchOptionalGeneratorType:
OptionalOf<Torch_GeneratorType, "Optional torch Generator type">;
def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
class ListOf<list<Type> allowedTypes, string descr> :
@ -430,6 +436,7 @@ def AnyTorchType : AnyTypeOf<[
Torch_BoolType,
Torch_DictType,
Torch_DeviceType,
Torch_GeneratorType,
Torch_ListType,
Torch_LinearParamsType,
Torch_NumberType,

@ -170,4 +170,54 @@ def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [
}];
}
def TorchConversion_I64ToGeneratorOp : TorchConversion_Op<"i64_to_generator", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert an `i64` to an `Generator`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
I64:$operand
);
let results = (outs
Torch_GeneratorType:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def TorchConversion_GeneratorToI64Op : TorchConversion_Op<"generator_to_i64", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Convert a `Generator` to a `i64`";
let description = [{
This op is primarily useful as a materialization during dialect conversion.
}];
let arguments = (ins
Torch_GeneratorType:$operand
);
let results = (outs
I64:$result
);
let assemblyFormat = [{
$operand attr-dict
}];
}
def TorchConversion_GetNextSeedOp: TorchConversion_Op<"get_next_seed", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Get the next global seed";
let description = [{
This op returns the next global seed used by RNG ops.
}];
let results = (outs
I64:$result
);
let assemblyFormat = [{
attr-dict `:` `(``)` `->` qualified(type($result))
}];
}
#endif // TORCHCONVERSION_OPS

@ -24,6 +24,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass();
std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
} // namespace RefBackend
} // namespace torch
} // namespace mlir

@ -18,6 +18,12 @@ def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleO
let dependentDialects = ["memref::MemRefDialect"];
}
def InsertRngGlobals: Pass<"refback-insert-rng-globals", "ModuleOp"> {
let summary = "Insert global variables and sequence to get the next global seed for RNG ops";
let constructor = "mlir::torch::RefBackend::createInsertRngGlobalsPass();";
let dependentDialects = ["memref::MemRefDialect"];
}
def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "FuncOp"> {
let summary = "Expand ops into more primitive ops before LLVM lowering.";
let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";

@ -84,6 +84,18 @@ MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
return wrap(Torch::DeviceType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// torch.Generator type.
//===----------------------------------------------------------------------===//
bool torchMlirTypeIsATorchGenerator(MlirType t) {
return unwrap(t).isa<Torch::GeneratorType>();
}
MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
return wrap(Torch::GeneratorType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// torch.bool type.
//===----------------------------------------------------------------------===//

@ -21,11 +21,13 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;
using namespace mlir::torch::TorchConversion;
// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
@ -4312,6 +4314,103 @@ public:
};
} // namespace
namespace {
class ConvertPseudoAtenUniformOp
: public OpConversionPattern<PseudoAtenUniformOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(PseudoAtenUniformOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
Value self = adaptor.self();
Value from = adaptor.from();
Value to = adaptor.to();
Value generator = adaptor.generator();
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
Type elemTy = resultType.getElementType();
if (!elemTy.isa<mlir::FloatType>())
return rewriter.notifyMatchFailure(op, "this op only supports float types");
if (!generator.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure(
op, "The generator has to ben None because only global default "
"generator is supported");
// Build the core formula of the LCG algorithm, making use of the element
// index. For an output tensor of rank N:
// temp1 = (cast(i64, index(d0)) + seed) * multiplier + incrementStep
// ...
// tempN = (cast(i64, index(d(N-1))) + temp(N-1)) * multiplier + incrementStep
// Refer to https://reviews.llvm.org/D101364.
// The values of multiplier and incrementStep are taken from
// https://en.wikipedia.org/wiki/Linear_congruential_generator for m = 2^64.
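// As a concrete (hypothetical) rank-2 example, the element at index
// (d0, d1) with initial seed s becomes:
//   temp1 = (d0 + s) * multiplier + incrementStep
//   temp2 = (d1 + temp1) * multiplier + incrementStep
//   result = uitofp(temp2) * (max - min) * 5.4210108e-20 + min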
Value multiplier = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(6364136223846793005));
Value incrementStep = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(1442695040888963407));
// T(n) = (index + T(n-1)) * multiplier + incrementStep
auto getNextTemp = [&](OpBuilder &b, Value index, Value temp) {
Value castIndex =
b.create<arith::IndexCastOp>(loc, b.getI64Type(), index);
Value add = b.create<arith::AddIOp>(loc, castIndex, temp);
Value mult = b.create<arith::MulIOp>(loc, add, multiplier);
return b.create<arith::AddIOp>(loc, mult, incrementStep);
};
// Get initial seed, min and max used by `linalg.generic` compute payload.
Value initialSeed = rewriter.create<GetNextSeedOp>(loc);
Value min = convertScalarToDtype(rewriter, loc, from, elemTy);
Value max = convertScalarToDtype(rewriter, loc, to, elemTy);
// Construct the `linalg.generic` op.
auto resultRank = resultType.getRank();
SmallVector<AffineMap, 1> indexingMaps(
1, rewriter.getMultiDimIdentityMap(resultRank));
SmallVector<StringRef> iteratorTypes(resultRank,
getParallelIteratorTypeName());
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, self);
Value initTensor =
rewriter.create<linalg::InitTensorOp>(loc, sizes, elemTy);
Value uniformRes =
rewriter
.create<linalg::GenericOp>(
loc, resultType, /*inputs=*/ValueRange{},
/*outputs=*/initTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value temp = initialSeed;
for (int i = 0; i < resultRank; i++) {
Value index = b.create<linalg::IndexOp>(loc, i);
temp = getNextTemp(b, index, temp);
}
// scale = (max - min) * const(F64, 5.4210108E-20), i.e.
// (max - min) / (2^64 - 1), derived from
// rand(min, max) = min + rand() * (max - min) / RAND_MAX
// with RAND_MAX = 2^64 - 1.
Value epsilon = b.create<arith::ConstantOp>(
loc, b.getFloatAttr(min.getType(), 5.4210108E-20));
Value range = b.create<arith::SubFOp>(loc, max, min);
Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
// res = cast(elemTy, tempN) * scale + min
Value updateFloat =
b.create<arith::UIToFPOp>(loc, elemTy, temp);
Value updateScaled =
b.create<arith::MulFOp>(loc, updateFloat, scale);
Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
b.create<linalg::YieldOp>(loc, res);
})
.getResult(0);
Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, uniformRes);
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@ -4335,6 +4434,7 @@ public:
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
math::MathDialect, tensor::TensorDialect,
arith::ArithmeticDialect>();
target.addLegalOp<GetNextSeedOp>();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
@ -4433,6 +4533,8 @@ public:
target.addIllegalOp<AtenArangeStartStepOp>();
patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
target.addIllegalOp<AtenIndexTensorOp>();
patterns.add<ConvertPseudoAtenUniformOp>(typeConverter, context);
target.addIllegalOp<PseudoAtenUniformOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))

@ -117,6 +117,29 @@ public:
};
} // namespace
// Reduce ops without value semantics whose corresponding variant without
// the trailing underscore doesn't exist in the registry.
namespace {
class ReduceNonValueSemanticOps : public RewritePattern {
public:
ReduceNonValueSemanticOps(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
if (!isa<AtenUniform_Op>(op))
return failure();
Operation *newOp = rewriter.create<PseudoAtenUniformOp>(
op->getLoc(), op->getResultTypes(), op->getOperands());
auto tensor =
rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
rewriter.replaceOp(op, op->getOperand(0));
return success();
}
};
} // namespace
namespace {
// Reduce the "trailing underscore inplace variant" to the value semantic
// variant + an overwrite of the original "self" argument.
@ -174,9 +197,11 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
patterns.add<ConvertToImmutableTensors>(context);
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
patterns.add<ReduceNonValueSemanticOps>(context);
ConversionTarget target(*context);
target.addIllegalOp<NonValueTensorLiteralOp>();
target.addIllegalOp<AtenUniform_Op>();
target.markUnknownOpDynamicallyLegal([](Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
auto hasValueSemantics = [](Type t) {

@ -242,7 +242,8 @@ public:
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
AtenAbsOp, AtenThresholdOp, AtenSquareOp>(op)) {
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp>(
op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]);
}

@ -97,8 +97,11 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
// Other builtin integer types could be handled by other materializers.
if (!(type.getWidth() == 64 && type.isSignless()))
return None;
// Other input types to be converted to i64 are handled by other
// materializers.
if (!inputs[0].getType().isa<Torch::IntType>())
return None;
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<Torch::IntType>());
return builder.create<ToI64Op>(loc, inputs[0]).getResult();
});
auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
@ -134,12 +137,43 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
typeConverter.addArgumentMaterialization(sourceMaterialization);
}
static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
TypeConverter &typeConverter) {
target.addLegalOp<TorchConversion::GeneratorToI64Op,
TorchConversion::I64ToGeneratorOp>();
typeConverter.addConversion([](Torch::GeneratorType type) -> Optional<Type> {
return IntegerType::get(type.getContext(), 64);
});
typeConverter.addTargetMaterialization([](OpBuilder &builder,
IntegerType type, ValueRange inputs,
Location loc) -> Optional<Value> {
// Other builtin integer types could be handled by other materializers.
if (!(type.getWidth() == 64 && type.isSignless()))
return None;
// Other input types to be converted to i64 are handled by other
// materializers.
if (!inputs[0].getType().isa<Torch::GeneratorType>())
return None;
assert(inputs.size() == 1);
return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult();
});
auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<IntegerType>());
return builder.create<I64ToGeneratorOp>(loc, inputs[0]);
};
typeConverter.addSourceMaterialization(sourceMaterialization);
typeConverter.addArgumentMaterialization(sourceMaterialization);
}
void mlir::torch::TorchConversion::setupBackendTypeConversion(
ConversionTarget &target, TypeConverter &typeConverter) {
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
setupTorchBoolToI1Conversion(target, typeConverter);
setupTorchIntToI64Conversion(target, typeConverter);
setupTorchFloatToF64Conversion(target, typeConverter);
setupTorchGeneratorToI64Conversion(target, typeConverter);
}
//===----------------------------------------------------------------------===//
@ -255,8 +289,8 @@ struct FinalizingBackendTypeConversionPass
// Mark materializations as illegal in this pass (since we are finalizing)
// and add patterns that eliminate them.
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
FromI64Op, ToI64Op, FromF64Op, ToF64Op>(target, patterns,
typeConverter);
FromI64Op, ToI64Op, FromF64Op, ToF64Op, I64ToGeneratorOp,
GeneratorToI64Op>(target, patterns, typeConverter);
// If all result types are legal, and all block arguments are legal, then
// all types in the program are legal.

@ -14,6 +14,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/DialectConversion.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
@ -55,6 +56,8 @@ class VerifyLinalgOnTensorsBackendContractPass
// Structural operations.
target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
target.addDynamicallyLegalOp<GetNextSeedOp>(opHasLegalTypes);
// Basic scalar operations.
target.addDynamicallyLegalDialect<StandardOpsDialect>(isLegalScalarOp);
target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);

@ -21,9 +21,10 @@
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Transforms/DialectConversion.h"
#include "set"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/RefBackend/Passes.h"
#include <numeric>
#include <set>
using namespace mlir;
using namespace mlir::torch;
@ -112,6 +113,10 @@ static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName,
static LogicalResult mungeFunction(
FuncOp func, std::set<std::string> &supportedConsumeFuncReturnFuncs,
std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
// We only need to munge functions that are callable from outside of the
// module.
if (func.isPrivate())
return success();
// Add `llvm.emit_c_interface`.
// This allows ExecutionEngine to resolve the symbol properly.
addEmitCInterfaceAttr(func);
@ -163,9 +168,10 @@ static LogicalResult mungeFunction(
auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end();
std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
op.emitError("must have one return value of memref types or scalar types "
"of i32, i64, f32, f64, i1, or two return values of memref "
"f32 and i64, or three return values of memref f32");
op.emitError("Supported return types:"
"mri1, mri32, mri64, mrf32, mrf64, i64, f32, f64,"
"(mrf32, mri64), (mrf32, mrf32), (mrf64, mrf64),"
"(mrf32, mrf32, mrf32)");
isSupported = false;
}
@ -193,9 +199,18 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
Type f32 = b.getF32Type();
Type f64 = b.getF64Type();
SmallVector<TypeRange> supportedReturnTypes = {
mri1, mri32, mri64, mrf32, mrf64,
i64, f32, f64, {mrf32, mri64}, {mrf32, mrf32, mrf32}};
SmallVector<TypeRange> supportedReturnTypes = {mri1,
mri32,
mri64,
mrf32,
mrf64,
i64,
f32,
f64,
{mrf32, mri64},
{mrf32, mrf32},
{mrf64, mrf64},
{mrf32, mrf32, mrf32}};
llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
@ -234,6 +249,80 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() {
return std::make_unique<MungeCallingConventions>();
}
//===----------------------------------------------------------------------===//
// InsertRngGlobals
//===----------------------------------------------------------------------===//
static constexpr StringRef getSeedGlobalVarName() { return "global_seed"; }
// Declare a memref<i64> global variable for the seed.
static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
b.setInsertionPointToStart(module.getBody());
Type elemTy = b.getI64Type();
auto memref0D = MemRefType::get({}, elemTy);
auto tensor0D = RankedTensorType::get({}, elemTy);
b.create<memref::GlobalOp>(
UnknownLoc::get(b.getContext()), getSeedGlobalVarName(),
/*sym_visibility=*/b.getStringAttr("private"),
/*type=*/memref0D,
/*initial_value=*/DenseIntElementsAttr::get(tensor0D, {APInt(64, 0)}),
/*constant=*/false,
/*alignment=*/nullptr);
}
// Generate the sequence for getting the next seed with an LCG step:
// nextSeed = (multiplier * currentSeed + incrementStep) mod 128.
// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
static Value lowerGetNextSeed(OpBuilder &b, Location loc) {
// Get the current seed value.
auto memref0DType = MemRefType::get({}, b.getI64Type());
Value globalVar =
b.create<memref::GetGlobalOp>(loc, memref0DType, getSeedGlobalVarName());
Value currentSeed = b.create<memref::LoadOp>(loc, globalVar);
// The value of multiplier and incrementStep are referenced from
// https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
Value multiplier = b.create<arith::ConstantOp>(
loc, b.getI64IntegerAttr(6364136223846793005));
Value incrementStep = b.create<arith::ConstantOp>(
loc, b.getI64IntegerAttr(1442695040888963407));
// temp = multiplier * currentSeed + incrementStep
Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier);
Value temp = b.create<arith::AddIOp>(loc, mul, incrementStep);
// temp mod 128 = temp & 127
Value cst127 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(127));
Value nextSeed = b.create<arith::AndIOp>(loc, temp, cst127);
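// E.g., starting from the zero-initialized seed, the first call returns
// (0 * multiplier + incrementStep) & 127 = 79.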
b.create<memref::StoreOp>(loc, nextSeed, globalVar);
return nextSeed;
}
// The global seed is stored into a memref<i64> global variable as the only
// element.
namespace {
class InsertRngGlobals : public InsertRngGlobalsBase<InsertRngGlobals> {
void runOnOperation() override {
auto module = getOperation();
OpBuilder b(module.getBodyRegion());
createGlobalVariableForSeed(b, module);
SmallVector<Operation *> toErase;
module.walk([&](TorchConversion::GetNextSeedOp op) {
b.setInsertionPoint(op);
Value seed = lowerGetNextSeed(b, op.getLoc());
op.replaceAllUsesWith(seed);
toErase.push_back(op);
});
for (auto op : toErase)
op->erase();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::torch::RefBackend::createInsertRngGlobalsPass() {
return std::make_unique<InsertRngGlobals>();
}
//===----------------------------------------------------------------------===//
// ExpandOpsForLLVM
//===----------------------------------------------------------------------===//

@ -237,6 +237,8 @@ TORCH_TYPE_TO_ODS_TYPE = {
"Any": "AnyTorchType",
"Device": "Torch_DeviceType",
"Device?": "TorchOptionalDeviceType",
"Generator": "Torch_GeneratorType",
"Generator?": "TorchOptionalGeneratorType",
"str": "Torch_StringType",
"str?": "TorchOptionalStringType",
"str[]": "TorchStringListType",
@ -497,6 +499,10 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
# Ops without value semantics but the corresponding without trailing
# underscore variant doesn't exist.
emit("aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)")
emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")

@ -210,6 +210,9 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
case TypeKind::DeviceObjType: {
return torchMlirTorchDeviceTypeGet(context);
}
case TypeKind::GeneratorType: {
return torchMlirTorchGeneratorTypeGet(context);
}
default: {
std::stringstream message;
message << "unable to map Torch type '" << *torchType << "' to MLIR type";

@ -73,6 +73,22 @@ class RefBackendInvoker:
                    arg1,
                    np.int64)
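        # Callbacks for functions returning two memrefs of the same float
        # type; for instance, the UniformModule e2e test returns its f64
        # std and mean tensors through the (mrf64, mrf64) signature.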
        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mrf32_mrf32(arg0, arg1):
            self.result = unranked_memref_to_numpy(
                arg0, np.float32), unranked_memref_to_numpy(
                    arg1,
                    np.float32)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor))
        def consume_return_mrf64_mrf64(arg0, arg1):
            self.result = unranked_memref_to_numpy(
                arg0, np.float64), unranked_memref_to_numpy(
                    arg1,
                    np.float64)

        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor),
                          ctypes.POINTER(UnrankedMemRefDescriptor))
@ -110,6 +126,14 @@ class RefBackendInvoker:
"refbackend_consume_func_return_mrf32_mri64",
consume_return_mrf32_mri64)
self.ee.register_runtime(
"refbackend_consume_func_return_mrf32_mrf32",
consume_return_mrf32_mrf32)
self.ee.register_runtime(
"refbackend_consume_func_return_mrf64_mrf64",
consume_return_mrf64_mrf64)
self.ee.register_runtime(
"refbackend_consume_func_return_mrf32_mrf32_mrf32",
consume_return_mrf32_mrf32_mrf32)
@ -148,6 +172,9 @@ LOWERING_PIPELINE = ",".join([
    # returns void at the C level -- we get the return value by providing the
    # callback).
    "refback-munge-calling-conventions",
    # Insert a global variable and an instruction sequence for getting the
    # next global seed used by stateful RNG ops.
    "refback-insert-rng-globals",
    # Lower to LLVM
    "builtin.func(convert-linalg-to-loops)",
    "builtin.func(lower-affine)",

@ -155,7 +155,7 @@ class ValueReport:
            )
        if value.dtype != golden.dtype:
            return self._record_failure(
                f'shape ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
                f'dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
            )
        if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
            return self._record_failure(

@ -124,9 +124,23 @@ func @torch.tensor.literal() -> !torch.tensor {
// CHECK: %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<!torch.vtensor<[2,3],si64>> -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
// CHECK: }
func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor {
%tensor_optional_list = torch.prim.ListConstruct %indices : (!torch.tensor<[2,3],si64>) -> !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>>
%ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>> -> !torch.tensor
return %ret : !torch.tensor
}
// CHECK-LABEL: func @torch.aten.uniform_(
// CHECK-SAME: %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float,
// CHECK-SAME: %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor {
// CHECK: %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
// CHECK: %[[VRET:.*]] = torch.pseudo.aten.uniform %[[T_VTENSOR]], %[[MIN]], %[[MAX]], %[[GENERATOR]] :
// CHECK-SAME: !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
// CHECK: torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
// CHECK: return %[[T]] : !torch.tensor
func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
%ret = torch.aten.uniform_ %t, %min, %max, %generator: !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
return %ret : !torch.tensor
}

@ -42,6 +42,16 @@ func @eliminate_materializations$torch.float(%arg0: f64) -> f64 {
return %1 : f64
}
// CHECK-LABEL: func @eliminate_materializations$torch.Generator(
// CHECK-SAME: %[[VAL_0:.*]]: i64) -> i64 {
// CHECK: return %[[VAL_0]] : i64
// CHECK: }
func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 {
%0 = torch_c.i64_to_generator %arg0
%1 = torch_c.generator_to_i64 %0
return %1 : i64
}
// -----
func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {

@ -108,3 +108,10 @@ func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
func @identity$torch.float(%arg0: !torch.float) -> !torch.float {
return %arg0 : !torch.float
}
// CHECK-LABEL: func @identity$torch.Generator(
// CHECK-SAME: %[[VAL_0:.*]]: i64) -> i64 {
// CHECK: return %[[VAL_0]] : i64
func @identity$torch.Generator(%arg0: !torch.Generator) -> !torch.Generator {
return %arg0 : !torch.Generator
}

@ -0,0 +1,20 @@
// RUN: torch-mlir-opt %s -refback-insert-rng-globals -split-input-file | FileCheck %s
// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
// CHECK-LABEL: func @f() -> i64 {
// CHECK: %[[MEMREF:.*]] = memref.get_global @global_seed : memref<i64>
// CHECK: %[[SEED:.*]] = memref.load %[[MEMREF]][] : memref<i64>
// CHECK: %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
// CHECK: %[[INC:.*]] = arith.constant 1442695040888963407 : i64
// CHECK: %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
// CHECK: %[[TEMP:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
// CHECK: %[[CST127:.*]] = arith.constant 127 : i64
// CHECK: %[[NEXT_SEED:.*]] = arith.andi %[[TEMP]], %[[CST127]] : i64
// CHECK: memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
// CHECK: return %[[NEXT_SEED]] : i64
module {
func @f() -> i64 {
%seed = torch_c.get_next_seed : () -> i64
return %seed : i64
}
}