mirror of https://github.com/llvm/torch-mlir
[Torch][Linalg] Add basic support for RNG
This PR includes the following pieces:
- Add a torch `Generator` type. The `Generator` type is converted to i64 in the refbackend type converter.
- Add seed management support for the default global generator. The `torch_c.get_next_seed` op is used to get the seed. On the refbackend, `torch_c.get_next_seed` is lowered to a load/store of the 0-d `global_seed` memref<i64> global variable in the `InsertRngGlobals` pass.
- Add `aten.uniform_` (with testing) as an example RNG op. Add the `torch.pseudo.aten.uniform` op, which has the same operands and result as `aten.uniform_` from the op registry except that it has value semantics.
parent 0f083e770a
commit 0cb216a1ad
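For context, at the PyTorch level the op this patch wires up behaves as follows (an illustration of `aten.uniform_` semantics, not part of the patch):

import torch

# `uniform_` fills a tensor in place with samples drawn uniformly from
# [from, to) and returns the mutated tensor itself.
x = torch.empty(2, 3, dtype=torch.float64)
torch.ops.aten.uniform_(x, 1.0, 10.0)
assert 1.0 <= float(x.min()) and float(x.max()) < 10.0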
@@ -50,6 +50,7 @@ from . import constant_alloc
 from . import threshold
 from . import histogram_binning_calibration
 from . import table_batch_embedding
+from . import rng
 
 def _get_argparse():
     config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'external']
@@ -0,0 +1,44 @@
+import torch
+
+from torch_mlir_e2e_test.torchscript.framework import TestUtils
+from torch_mlir_e2e_test.torchscript.registry import register_test_case
+from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
+
+
+# ==============================================================================
+class UniformModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+        ([-1, -1, -1], torch.float64, True),
+    ])
+    def forward(self, x, y, z):
+        a = torch.ops.aten.uniform_(x, 1.0, 10.0)
+        b = torch.ops.aten.uniform_(y, -20.0, -5.0)
+        c = torch.ops.aten.uniform_(z, -15.0, 3.0)
+        std = torch.cat([
+            torch.flatten(torch.std(a)),
+            torch.flatten(torch.std(b)),
+            torch.flatten(torch.std(c))
+        ])
+        mean = torch.cat([
+            torch.flatten(torch.mean(a)),
+            torch.flatten(torch.mean(b)),
+            torch.flatten(torch.mean(c))
+        ])
+        return std, mean
+
+
+@register_test_case(module_factory=lambda: UniformModule())
+def UniformModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.rand(256, 512, 64).double(),
+        tu.rand(512, 1024, 128).double(),
+        tu.rand(512, 256, 1024).double())
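Because RNG output cannot be compared elementwise against a golden tensor, the test returns aggregate statistics (mean and std), which should agree between backends as long as both sample from the same distribution. For U(a, b), mean ≈ (a + b) / 2 and std ≈ (b - a) / sqrt(12); a quick sanity check of that reasoning (illustration only, not part of the patch):

import math
import torch

# For uniform samples on [a, b): mean -> (a + b) / 2, std -> (b - a) / sqrt(12).
a, b = 1.0, 10.0
x = torch.empty(1_000_000, dtype=torch.float64).uniform_(a, b)
print(float(x.mean()), (a + b) / 2)             # both ~5.5
print(float(x.std()), (b - a) / math.sqrt(12))  # both ~2.598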
@@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
 /// Gets the !torch.Device type.
 MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
 
+//===----------------------------------------------------------------------===//
+// torch.Generator type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.Generator type.
+MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t);
+
+/// Gets the !torch.Generator type.
+MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context);
+
 //===----------------------------------------------------------------------===//
 // torch.bool type.
 //===----------------------------------------------------------------------===//
@@ -1325,6 +1325,22 @@ def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
   let assemblyFormat = "$grad_output `,` $self `,` $threshold attr-dict `:` qualified(type($grad_output)) `,` qualified(type($self)) `,` qualified(type($threshold)) `->` qualified(type($result))";
 }
 
+def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$from,
+    Torch_FloatType:$to,
+    TorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` qualified(type($self)) `,` qualified(type($from)) `,` qualified(type($to)) `,` qualified(type($generator)) `->` qualified(type($result))";
+}
+
 def Torch_AtenTriuOp : Torch_Op<"aten.triu", [
     AllowsTypeRefinement,
     HasValueSemantics
@@ -912,4 +912,23 @@ def Torch_OverwriteTensorOp : Torch_Op<"overwrite.tensor", [
   }];
 }
 
+// The corresponding variant without the trailing underscore for
+// `torch.aten.uniform_` doesn't exist in the PyTorch ops registry, so add it
+// here.
+def Torch_PseudoAtenUniformOp: Torch_Op<"pseudo.aten.uniform", [
+    AllowsTypeRefinement,
+    HasValueSemantics
+  ]> {
+  let summary = "`uniform : (Tensor, float, float, Generator?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_FloatType:$from,
+    Torch_FloatType:$to,
+    TorchOptionalGeneratorType:$generator
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let assemblyFormat = "$self `,` $from `,` $to `,` $generator attr-dict `:` type($self) `,` type($from) `,` type($to) `,` type($generator) `->` type($result)";
+}
+
 #endif // TORCH_OPS
@@ -243,6 +243,10 @@ def Torch_DeviceType : Torch_Type<"Device", "Device"> {
   let summary = "Torch device";
 }
 
+def Torch_GeneratorType : Torch_Type<"Generator", "Generator"> {
+  let summary = "Torch Generator for producing pseudo-random numbers";
+}
+
 def Torch_BoolType : Torch_Type<"Bool", "bool"> {
   let summary = "Torch BoolType";
   let description = [{
@@ -384,6 +388,8 @@ def TorchOptionalStringType:
     OptionalOf<Torch_StringType, "Optional torch Str type">;
 def TorchOptionalDeviceType:
     OptionalOf<Torch_DeviceType, "Optional torch device type">;
+def TorchOptionalGeneratorType:
+    OptionalOf<Torch_GeneratorType, "Optional torch Generator type">;
 
 def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">;
 class ListOf<list<Type> allowedTypes, string descr> :
@@ -430,6 +436,7 @@ def AnyTorchType : AnyTypeOf<[
     Torch_BoolType,
     Torch_DictType,
     Torch_DeviceType,
+    Torch_GeneratorType,
     Torch_ListType,
     Torch_LinearParamsType,
     Torch_NumberType,
@@ -170,4 +170,54 @@ def TorchConversion_FromF64Op : TorchConversion_Op<"from_f64", [
   }];
 }
 
+def TorchConversion_I64ToGeneratorOp : TorchConversion_Op<"i64_to_generator", [
+    DeclareOpInterfaceMethods<InferTypeOpInterface>
+  ]> {
+  let summary = "Convert an `i64` to a `Generator`";
+  let description = [{
+    This op is primarily useful as a materialization during dialect conversion.
+  }];
+  let arguments = (ins
+    I64:$operand
+  );
+  let results = (outs
+    Torch_GeneratorType:$result
+  );
+  let assemblyFormat = [{
+    $operand attr-dict
+  }];
+}
+
+def TorchConversion_GeneratorToI64Op : TorchConversion_Op<"generator_to_i64", [
+    DeclareOpInterfaceMethods<InferTypeOpInterface>
+  ]> {
+  let summary = "Convert a `Generator` to an `i64`";
+  let description = [{
+    This op is primarily useful as a materialization during dialect conversion.
+  }];
+  let arguments = (ins
+    Torch_GeneratorType:$operand
+  );
+  let results = (outs
+    I64:$result
+  );
+  let assemblyFormat = [{
+    $operand attr-dict
+  }];
+}
+
+def TorchConversion_GetNextSeedOp: TorchConversion_Op<"get_next_seed", [
+    DeclareOpInterfaceMethods<InferTypeOpInterface>
+  ]> {
+  let summary = "Get the next global seed";
+  let description = [{
+    This op is for getting the next global seed for RNG.
+  }];
+  let results = (outs
+    I64:$result
+  );
+  let assemblyFormat = [{
+    attr-dict `:` `(` `)` `->` qualified(type($result))
+  }];
+}
 #endif // TORCHCONVERSION_OPS
@@ -24,6 +24,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createMungeCallingConventionsPass();
 
 std::unique_ptr<OperationPass<FuncOp>> createExpandOpsForLLVMPass();
 
+std::unique_ptr<OperationPass<ModuleOp>> createInsertRngGlobalsPass();
 } // namespace RefBackend
 } // namespace torch
 } // namespace mlir
@@ -18,6 +18,12 @@ def MungeCallingConventions : Pass<"refback-munge-calling-conventions", "ModuleOp"> {
   let dependentDialects = ["memref::MemRefDialect"];
 }
 
+def InsertRngGlobals : Pass<"refback-insert-rng-globals", "ModuleOp"> {
+  let summary = "Insert global variables and sequence to get the next global seed for RNG ops";
+  let constructor = "mlir::torch::RefBackend::createInsertRngGlobalsPass();";
+  let dependentDialects = ["memref::MemRefDialect"];
+}
+
 def ExpandOpsForLLVM : Pass<"refback-expand-ops-for-llvm", "FuncOp"> {
   let summary = "Expand ops into more primitive ops before LLVM lowering.";
   let constructor = "mlir::torch::RefBackend::createExpandOpsForLLVMPass();";
@@ -84,6 +84,18 @@ MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
   return wrap(Torch::DeviceType::get(unwrap(context)));
 }
 
+//===----------------------------------------------------------------------===//
+// torch.Generator type.
+//===----------------------------------------------------------------------===//
+
+bool torchMlirTypeIsATorchGenerator(MlirType t) {
+  return unwrap(t).isa<Torch::GeneratorType>();
+}
+
+MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
+  return wrap(Torch::GeneratorType::get(unwrap(context)));
+}
+
 //===----------------------------------------------------------------------===//
 // torch.bool type.
 //===----------------------------------------------------------------------===//
@@ -21,11 +21,13 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
 
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::Torch;
+using namespace mlir::torch::TorchConversion;
 
 // -----------------------------------------------------------------------------
 // Patterns (as this grows, it should be organized into multiple files)
@@ -4312,6 +4314,103 @@ public:
 };
 } // namespace
 
+namespace {
+class ConvertPseudoAtenUniformOp
+    : public OpConversionPattern<PseudoAtenUniformOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PseudoAtenUniformOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+    Location loc = op.getLoc();
+    Value self = adaptor.self();
+    Value from = adaptor.from();
+    Value to = adaptor.to();
+    Value generator = adaptor.generator();
+    RankedTensorType resultType = self.getType().cast<RankedTensorType>();
+    Type elemTy = resultType.getElementType();
+
+    if (!elemTy.isa<mlir::FloatType>())
+      return rewriter.notifyMatchFailure(op, "This op only supports float type");
+
+    if (!generator.getType().isa<Torch::NoneType>())
+      return rewriter.notifyMatchFailure(
+          op, "The generator has to be None because only the global default "
+              "generator is supported");
+
+    // Build the core formula of the LCG algorithm that makes use of the
+    // element index. For an output tensor of rank N:
+    //   temp1 = (cast(I64, index(D.0)) + seed) * multiplier + incrementStep
+    //   ...
+    //   tempN = (cast(I64, index(D.(N-1))) + tempN-1) * multiplier + incrementStep
+    // Refer to https://reviews.llvm.org/D101364.
+    // The values of multiplier and incrementStep are referenced from
+    // https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
+    Value multiplier = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI64IntegerAttr(6364136223846793005));
+    Value incrementStep = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI64IntegerAttr(1442695040888963407));
+    // Tn = (index + Tn-1) * multiplier + incrementStep
+    auto getNextTemp = [&](OpBuilder &b, Value index, Value temp) {
+      Value castIndex =
+          b.create<arith::IndexCastOp>(loc, b.getI64Type(), index);
+      Value add = b.create<arith::AddIOp>(loc, castIndex, temp);
+      Value mult = b.create<arith::MulIOp>(loc, add, multiplier);
+      return b.create<arith::AddIOp>(loc, mult, incrementStep);
+    };
+
+    // Get the initial seed, min, and max used by the `linalg.generic` compute
+    // payload.
+    Value initialSeed = rewriter.create<GetNextSeedOp>(loc);
+    Value min = convertScalarToDtype(rewriter, loc, from, elemTy);
+    Value max = convertScalarToDtype(rewriter, loc, to, elemTy);
+
+    // Construct the `linalg.generic` op.
+    auto resultRank = resultType.getRank();
+    SmallVector<AffineMap, 1> indexingMaps(
+        1, rewriter.getMultiDimIdentityMap(resultRank));
+    SmallVector<StringRef> iteratorTypes(resultRank,
+                                         getParallelIteratorTypeName());
+    SmallVector<Value> sizes = getTensorSizes(rewriter, loc, self);
+    Value initTensor =
+        rewriter.create<linalg::InitTensorOp>(loc, sizes, elemTy);
+    Value uniformRes =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, resultType, /*inputs=*/ValueRange{},
+                /*outputs=*/initTensor, indexingMaps, iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value temp = initialSeed;
+                  for (int i = 0; i < resultRank; i++) {
+                    Value index = b.create<linalg::IndexOp>(loc, i);
+                    temp = getNextTemp(b, index, temp);
+                  }
+                  // scale = (max - min) * const(F64, 5.4210108E-20),
+                  // which is derived from rand(min,max) =
+                  // rand()/(RAND_MAX/(max-min)) where RAND_MAX = 2^64 - 1.
+                  Value epsilon = b.create<arith::ConstantOp>(
+                      loc, b.getFloatAttr(min.getType(), 5.4210108E-20));
+                  Value range = b.create<arith::SubFOp>(loc, max, min);
+                  Value scale = b.create<arith::MulFOp>(loc, range, epsilon);
+
+                  // res = cast(F64, tempN) * scale + min
+                  Value updateFloat =
+                      b.create<arith::UIToFPOp>(loc, elemTy, temp);
+                  Value updateScaled =
+                      b.create<arith::MulFOp>(loc, updateFloat, scale);
+                  Value res = b.create<arith::AddFOp>(loc, updateScaled, min);
+                  b.create<linalg::YieldOp>(loc, res);
+                })
+            .getResult(0);
+
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, uniformRes);
+    return success();
+  }
+};
+} // namespace
+
 // -----------------------------------------------------------------------------
 // The pass
 // -----------------------------------------------------------------------------
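For intuition, here is a rough Python model of what the generated `linalg.generic` payload computes per element; the constant 5.4210108E-20 is approximately 2^-64, i.e. 1/(RAND_MAX + 1) for a 64-bit generator. This is a sketch of the lowering's arithmetic, not code from the patch:

# One LCG step per dimension mixes the element index into the seed; the final
# 64-bit value is treated as unsigned and scaled into [lo, hi).
MULTIPLIER = 6364136223846793005
INCREMENT = 1442695040888963407
MASK64 = (1 << 64) - 1  # Python ints are unbounded; emulate i64 wraparound.

def uniform_element(indices, seed, lo, hi):
    temp = seed
    for idx in indices:
        temp = ((idx + temp) * MULTIPLIER + INCREMENT) & MASK64
    return temp * ((hi - lo) * 5.4210108e-20) + lo

# Example: fill a 2x3 "tensor" from one initial seed.
vals = [[uniform_element((i, j), seed=42, lo=1.0, hi=10.0) for j in range(3)]
        for i in range(2)]
print(vals)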
@@ -4335,6 +4434,7 @@ public:
     target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
                            math::MathDialect, tensor::TensorDialect,
                            arith::ArithmeticDialect>();
+    target.addLegalOp<GetNextSeedOp>();
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
@@ -4433,6 +4533,8 @@ public:
     target.addIllegalOp<AtenArangeStartStepOp>();
     patterns.add<ConvertAtenIndexTensorOp>(typeConverter, context);
     target.addIllegalOp<AtenIndexTensorOp>();
+    patterns.add<ConvertPseudoAtenUniformOp>(typeConverter, context);
+    target.addIllegalOp<PseudoAtenUniformOp>();
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -117,6 +117,29 @@ public:
 };
 } // namespace
 
+// Reduce ops without value semantics for which the corresponding variant
+// without the trailing underscore doesn't exist.
+namespace {
+class ReduceNonValueSemanticOps : public RewritePattern {
+public:
+  ReduceNonValueSemanticOps(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!isa<AtenUniform_Op>(op))
+      return failure();
+
+    Operation *newOp = rewriter.create<PseudoAtenUniformOp>(
+        op->getLoc(), op->getResultTypes(), op->getOperands());
+    auto tensor =
+        rewriter.create<CopyToValueTensorOp>(op->getLoc(), newOp->getResult(0));
+    rewriter.create<OverwriteTensorOp>(op->getLoc(), tensor, op->getOperand(0));
+    rewriter.replaceOp(op, op->getOperand(0));
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Reduce the "trailing underscore inplace variant" to the value semantic
 // variant + an overwrite of the original "self" argument.
@@ -174,9 +197,11 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
     patterns.add<ConvertToImmutableTensors>(context);
     patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
     patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
+    patterns.add<ReduceNonValueSemanticOps>(context);
 
     ConversionTarget target(*context);
     target.addIllegalOp<NonValueTensorLiteralOp>();
+    target.addIllegalOp<AtenUniform_Op>();
     target.markUnknownOpDynamicallyLegal([](Operation *op) {
       if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
         auto hasValueSemantics = [](Type t) {
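Conceptually, the `ReduceNonValueSemanticOps` pattern rewrites the in-place op into its value-semantics counterpart plus a write-back of the original buffer. A hypothetical PyTorch-level analogy of that desugaring (illustration only, not part of the patch):

import torch

def uniform_inplace_desugared(x: torch.Tensor, a: float, b: float) -> torch.Tensor:
    # Stands in for torch.pseudo.aten.uniform: a pure op producing a fresh tensor.
    fresh = torch.empty_like(x).uniform_(a, b)
    # Stands in for torch.overwrite.tensor: write the result back into `self`.
    x.copy_(fresh)
    # The original `self` is returned, matching rewriter.replaceOp above.
    return x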
@@ -242,7 +242,8 @@ public:
             AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
             AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
             AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
-            AtenAbsOp, AtenThresholdOp, AtenSquareOp>(op)) {
+            AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp>(
+            op)) {
       return getLatticeElement(op->getResult(0)).join(*operands[0]);
     }
 
@@ -97,8 +97,11 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
     // Other builtin integer types could be handled by other materializers.
     if (!(type.getWidth() == 64 && type.isSignless()))
      return None;
+    // Other input types to be converted to i64 are handled by other
+    // materializers.
+    if (!inputs[0].getType().isa<Torch::IntType>())
+      return None;
     assert(inputs.size() == 1);
-    assert(inputs[0].getType().isa<Torch::IntType>());
     return builder.create<ToI64Op>(loc, inputs[0]).getResult();
   });
   auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
@@ -134,12 +137,43 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
   typeConverter.addArgumentMaterialization(sourceMaterialization);
 }
 
+static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
+                                               TypeConverter &typeConverter) {
+  target.addLegalOp<TorchConversion::GeneratorToI64Op,
+                    TorchConversion::I64ToGeneratorOp>();
+  typeConverter.addConversion([](Torch::GeneratorType type) -> Optional<Type> {
+    return IntegerType::get(type.getContext(), 64);
+  });
+  typeConverter.addTargetMaterialization([](OpBuilder &builder,
+                                            IntegerType type, ValueRange inputs,
+                                            Location loc) -> Optional<Value> {
+    // Other builtin integer types could be handled by other materializers.
+    if (!(type.getWidth() == 64 && type.isSignless()))
+      return None;
+    // Other input types to be converted to i64 are handled by other
+    // materializers.
+    if (!inputs[0].getType().isa<Torch::GeneratorType>())
+      return None;
+    assert(inputs.size() == 1);
+    return builder.create<GeneratorToI64Op>(loc, inputs[0]).getResult();
+  });
+  auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type,
+                                  ValueRange inputs, Location loc) -> Value {
+    assert(inputs.size() == 1);
+    assert(inputs[0].getType().isa<IntegerType>());
+    return builder.create<I64ToGeneratorOp>(loc, inputs[0]);
+  };
+  typeConverter.addSourceMaterialization(sourceMaterialization);
+  typeConverter.addArgumentMaterialization(sourceMaterialization);
+}
+
 void mlir::torch::TorchConversion::setupBackendTypeConversion(
     ConversionTarget &target, TypeConverter &typeConverter) {
   setupValueTensorToBuiltinTensorConversion(target, typeConverter);
   setupTorchBoolToI1Conversion(target, typeConverter);
   setupTorchIntToI64Conversion(target, typeConverter);
   setupTorchFloatToF64Conversion(target, typeConverter);
+  setupTorchGeneratorToI64Conversion(target, typeConverter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -255,8 +289,8 @@ struct FinalizingBackendTypeConversionPass
     // Mark materializations as illegal in this pass (since we are finalizing)
     // and add patterns that eliminate them.
     setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
-                      FromI64Op, ToI64Op, FromF64Op, ToF64Op>(target, patterns,
-                                                              typeConverter);
+                      FromI64Op, ToI64Op, FromF64Op, ToF64Op, I64ToGeneratorOp,
+                      GeneratorToI64Op>(target, patterns, typeConverter);
 
     // If all result types are legal, and all block arguments are legal, then
     // all types in the program are legal.
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
 
 #include "mlir/IR/BuiltinOps.h"
@@ -55,6 +56,8 @@ class VerifyLinalgOnTensorsBackendContractPass
     // Structural operations.
     target.addDynamicallyLegalOp<ModuleOp, FuncOp, ReturnOp>(opHasLegalTypes);
 
+    target.addDynamicallyLegalOp<GetNextSeedOp>(opHasLegalTypes);
+
     // Basic scalar operations.
     target.addDynamicallyLegalDialect<StandardOpsDialect>(isLegalScalarOp);
     target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
@@ -21,9 +21,10 @@
 #include "mlir/Dialect/Math/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "set"
+#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/RefBackend/Passes.h"
 #include <numeric>
+#include <set>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -112,6 +113,10 @@ static void replaceReturnWithCall(OpBuilder b, ReturnOp op, StringRef funcName,
 static LogicalResult mungeFunction(
     FuncOp func, std::set<std::string> &supportedConsumeFuncReturnFuncs,
     std::map<std::string, std::vector<Type>> &invokedConsumeFuncReturnFuncs) {
+  // We only need to munge functions that are callable from outside of the
+  // module.
+  if (func.isPrivate())
+    return success();
   // Add `llvm.emit_c_interface`.
   // This allows ExecutionEngine to resolve the symbol properly.
   addEmitCInterfaceAttr(func);
@@ -163,9 +168,10 @@ static LogicalResult mungeFunction(
   auto supportedFuncsEnd = supportedConsumeFuncReturnFuncs.end();
   std::string funcName = getConsumeReturnFunctionNameForReturnTypes(retTypes);
   if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
-    op.emitError("must have one return value of memref types or scalar types "
-                 "of i32, i64, f32, f64, i1, or two return values of memref "
-                 "f32 and i64, or three return values of memref f32");
+    op.emitError("Supported return types: "
+                 "mri1, mri32, mri64, mrf32, mrf64, i64, f32, f64, "
+                 "(mrf32, mri64), (mrf32, mrf32), (mrf64, mrf64), "
+                 "(mrf32, mrf32, mrf32)");
     isSupported = false;
   }
 
@@ -193,9 +199,18 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
   Type f32 = b.getF32Type();
   Type f64 = b.getF64Type();
 
-  SmallVector<TypeRange> supportedReturnTypes = {
-      mri1, mri32, mri64, mrf32, mrf64,
-      i64, f32, f64, {mrf32, mri64}, {mrf32, mrf32, mrf32}};
+  SmallVector<TypeRange> supportedReturnTypes = {mri1,
+                                                 mri32,
+                                                 mri64,
+                                                 mrf32,
+                                                 mrf64,
+                                                 i64,
+                                                 f32,
+                                                 f64,
+                                                 {mrf32, mri64},
+                                                 {mrf32, mrf32},
+                                                 {mrf64, mrf64},
+                                                 {mrf32, mrf32, mrf32}};
 
   llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
     funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
@@ -234,6 +249,80 @@ mlir::torch::RefBackend::createMungeCallingConventionsPass() {
   return std::make_unique<MungeCallingConventions>();
 }
 
+//===----------------------------------------------------------------------===//
+// InsertRngGlobals
+//===----------------------------------------------------------------------===//
+
+static constexpr StringRef getSeedGlobalVarName() { return "global_seed"; }
+
+// Declare a memref<i64> global variable for the seed.
+static void createGlobalVariableForSeed(OpBuilder &b, ModuleOp module) {
+  b.setInsertionPointToStart(module.getBody());
+  Type elemTy = b.getI64Type();
+  auto memref0D = MemRefType::get({}, elemTy);
+  auto tensor0D = RankedTensorType::get({}, elemTy);
+  b.create<memref::GlobalOp>(
+      UnknownLoc::get(b.getContext()), getSeedGlobalVarName(),
+      /*sym_visibility=*/b.getStringAttr("private"),
+      /*type=*/memref0D,
+      /*initial_value=*/DenseIntElementsAttr::get(tensor0D, {APInt(64, 0)}),
+      /*constant=*/false,
+      /*alignment=*/nullptr);
+}
+
+// Generate the sequence for getting the next seed with an LCG step:
+//   nextSeed = (multiplier * currentSeed + incrementStep) mod 128.
+// Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
+static Value lowerGetNextSeed(OpBuilder &b, Location loc) {
+  // Get the current seed value.
+  auto memref0DType = MemRefType::get({}, b.getI64Type());
+  Value globalVar =
+      b.create<memref::GetGlobalOp>(loc, memref0DType, getSeedGlobalVarName());
+  Value currentSeed = b.create<memref::LoadOp>(loc, globalVar);
+
+  // The values of multiplier and incrementStep are referenced from
+  // https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
+  Value multiplier = b.create<arith::ConstantOp>(
+      loc, b.getI64IntegerAttr(6364136223846793005));
+  Value incrementStep = b.create<arith::ConstantOp>(
+      loc, b.getI64IntegerAttr(1442695040888963407));
+  // temp = multiplier * currentSeed + incrementStep
+  Value mul = b.create<arith::MulIOp>(loc, currentSeed, multiplier);
+  Value temp = b.create<arith::AddIOp>(loc, mul, incrementStep);
+  // temp mod 128 = temp & 127
+  Value cst127 = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(127));
+  Value nextSeed = b.create<arith::AndIOp>(loc, temp, cst127);
+  b.create<memref::StoreOp>(loc, nextSeed, globalVar);
+  return nextSeed;
+}
+
+// The global seed is stored in a memref<i64> global variable as the only
+// element.
+namespace {
+class InsertRngGlobals : public InsertRngGlobalsBase<InsertRngGlobals> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    OpBuilder b(module.getBodyRegion());
+    createGlobalVariableForSeed(b, module);
+    SmallVector<Operation *> toErase;
+    module.walk([&](TorchConversion::GetNextSeedOp op) {
+      b.setInsertionPoint(op);
+      Value seed = lowerGetNextSeed(b, op.getLoc());
+      op.replaceAllUsesWith(seed);
+      toErase.push_back(op);
+    });
+
+    for (auto op : toErase)
+      op->erase();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::torch::RefBackend::createInsertRngGlobalsPass() {
+  return std::make_unique<InsertRngGlobals>();
+}
+
 //===----------------------------------------------------------------------===//
 // ExpandOpsForLLVM
 //===----------------------------------------------------------------------===//
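A rough Python model of the seed state machine that `refback-insert-rng-globals` materializes (a sketch, not part of the patch). Note that because of the `& 127` mask, the stored seed only ever takes one of 128 values; the per-element LCG steps in the `linalg.generic` payload are what spread those seeds across the output:

MULTIPLIER = 6364136223846793005
INCREMENT = 1442695040888963407

_global_seed = 0  # models the memref<i64> global, initialized to dense<0>

def get_next_seed() -> int:
    # One LCG step, truncated to i64 and masked to the low 7 bits,
    # mirroring the lowered memref load/compute/store sequence.
    global _global_seed
    temp = (MULTIPLIER * _global_seed + INCREMENT) & ((1 << 64) - 1)
    _global_seed = temp & 127
    return _global_seed

print([get_next_seed() for _ in range(5)])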
@@ -237,6 +237,8 @@ TORCH_TYPE_TO_ODS_TYPE = {
     "Any": "AnyTorchType",
     "Device": "Torch_DeviceType",
     "Device?": "TorchOptionalDeviceType",
+    "Generator": "Torch_GeneratorType",
+    "Generator?": "TorchOptionalGeneratorType",
     "str": "Torch_StringType",
     "str?": "TorchOptionalStringType",
     "str[]": "TorchStringListType",
@@ -497,6 +499,10 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
     emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
 
+    # Ops without value semantics for which the corresponding variant without
+    # the trailing underscore doesn't exist.
+    emit("aten::uniform_ : (Tensor, float, float, Generator?) -> (Tensor)")
+
     emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants("aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)")
@@ -210,6 +210,9 @@ MlirType torch_mlir::getMlirTypeFromTorchType(MlirLocation loc,
   case TypeKind::DeviceObjType: {
     return torchMlirTorchDeviceTypeGet(context);
   }
+  case TypeKind::GeneratorType: {
+    return torchMlirTorchGeneratorTypeGet(context);
+  }
   default: {
     std::stringstream message;
     message << "unable to map Torch type '" << *torchType << "' to MLIR type";
@@ -73,6 +73,22 @@ class RefBackendInvoker:
                 arg1,
                 np.int64)
 
+        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
+                          ctypes.POINTER(UnrankedMemRefDescriptor))
+        def consume_return_mrf32_mrf32(arg0, arg1):
+            self.result = unranked_memref_to_numpy(
+                arg0, np.float32), unranked_memref_to_numpy(
+                    arg1,
+                    np.float32)
+
+        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
+                          ctypes.POINTER(UnrankedMemRefDescriptor))
+        def consume_return_mrf64_mrf64(arg0, arg1):
+            self.result = unranked_memref_to_numpy(
+                arg0, np.float64), unranked_memref_to_numpy(
+                    arg1,
+                    np.float64)
+
         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor),
                           ctypes.POINTER(UnrankedMemRefDescriptor),
                           ctypes.POINTER(UnrankedMemRefDescriptor))
@@ -110,6 +126,14 @@ class RefBackendInvoker:
             "refbackend_consume_func_return_mrf32_mri64",
             consume_return_mrf32_mri64)
 
+        self.ee.register_runtime(
+            "refbackend_consume_func_return_mrf32_mrf32",
+            consume_return_mrf32_mrf32)
+
+        self.ee.register_runtime(
+            "refbackend_consume_func_return_mrf64_mrf64",
+            consume_return_mrf64_mrf64)
+
         self.ee.register_runtime(
             "refbackend_consume_func_return_mrf32_mrf32_mrf32",
             consume_return_mrf32_mrf32_mrf32)
@@ -148,6 +172,9 @@ LOWERING_PIPELINE = ",".join([
    # returns void at the C level -- we get the return value by providing the
    # callback).
    "refback-munge-calling-conventions",
+   # Insert the global variable and the instruction sequence for getting the
+   # next global seed used in stateful RNG.
+   "refback-insert-rng-globals",
    # Lower to LLVM
    "builtin.func(convert-linalg-to-loops)",
    "builtin.func(lower-affine)",
@@ -155,7 +155,7 @@ class ValueReport:
             )
         if value.dtype != golden.dtype:
             return self._record_failure(
-                f'shape ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
+                f'dtype ({value.dtype}) is not equal to golden dtype ({golden.dtype})'
             )
         if not torch.allclose(value, golden, rtol=1e-03, atol=1e-07, equal_nan=True):
             return self._record_failure(
@@ -124,9 +124,23 @@ func @torch.tensor.literal() -> !torch.tensor {
 // CHECK:   %[[VRET:.*]] = torch.aten.index.Tensor %[[SELF_VTENSOR]], %[[INDICES_LIST]] : !torch.vtensor<[5],f32>, !torch.list<!torch.vtensor<[2,3],si64>> -> !torch.vtensor
 // CHECK:   %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
 // CHECK:   return %[[RET]] : !torch.tensor
 // CHECK: }
 func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor<[5],f32>, %indices: !torch.tensor<[2,3],si64>) -> !torch.tensor {
   %tensor_optional_list = torch.prim.ListConstruct %indices : (!torch.tensor<[2,3],si64>) -> !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>>
   %ret = torch.aten.index.Tensor %self, %tensor_optional_list : !torch.tensor<[5],f32>, !torch.list<!torch.optional<!torch.tensor<[2,3],si64>>> -> !torch.tensor
   return %ret : !torch.tensor
 }
 
+// CHECK-LABEL: func @torch.aten.uniform_(
+// CHECK-SAME:    %[[T:.*]]: !torch.tensor, %[[MIN:.*]]: !torch.float, %[[MAX:.*]]: !torch.float,
+// CHECK-SAME:    %[[GENERATOR:.*]]: !torch.none) -> !torch.tensor {
+// CHECK:   %[[T_VTENSOR:.*]] = torch.copy.to_vtensor %[[T]] : !torch.vtensor
+// CHECK:   %[[VRET:.*]] = torch.pseudo.aten.uniform %[[T_VTENSOR]], %[[MIN]], %[[MAX]], %[[GENERATOR]] :
+// CHECK-SAME:    !torch.vtensor, !torch.float, !torch.float, !torch.none -> !torch.vtensor
+// CHECK:   %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
+// CHECK:   %[[COPY_VTENSOR:.*]] = torch.copy.to_vtensor %[[RET]] : !torch.vtensor
+// CHECK:   torch.overwrite.tensor %[[COPY_VTENSOR]] overwrites %[[T]] : !torch.vtensor, !torch.tensor
+// CHECK:   return %[[T]] : !torch.tensor
+func @torch.aten.uniform_(%t: !torch.tensor, %min: !torch.float, %max: !torch.float, %generator: !torch.none) -> !torch.tensor {
+  %ret = torch.aten.uniform_ %t, %min, %max, %generator : !torch.tensor, !torch.float, !torch.float, !torch.none -> !torch.tensor
+  return %ret : !torch.tensor
+}
@@ -42,6 +42,16 @@ func @eliminate_materializations$torch.float(%arg0: f64) -> f64 {
   return %1 : f64
 }
 
+// CHECK-LABEL: func @eliminate_materializations$torch.Generator(
+// CHECK-SAME:    %[[VAL_0:.*]]: i64) -> i64 {
+// CHECK:   return %[[VAL_0]] : i64
+// CHECK: }
+func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 {
+  %0 = torch_c.i64_to_generator %arg0
+  %1 = torch_c.generator_to_i64 %0
+  return %1 : i64
+}
+
 // -----
 
 func @unable_to_convert_lone_buffer_cast() -> tensor<f32> {
@@ -108,3 +108,10 @@ func @identity$torch.int(%arg0: !torch.int) -> !torch.int {
 func @identity$torch.float(%arg0: !torch.float) -> !torch.float {
   return %arg0 : !torch.float
 }
+
+// CHECK-LABEL: func @identity$torch.Generator(
+// CHECK-SAME:    %[[VAL_0:.*]]: i64) -> i64 {
+// CHECK:   return %[[VAL_0]] : i64
+func @identity$torch.Generator(%arg0: !torch.Generator) -> !torch.Generator {
+  return %arg0 : !torch.Generator
+}
@@ -0,0 +1,20 @@
+// RUN: torch-mlir-opt %s -refback-insert-rng-globals -split-input-file | FileCheck %s
+
+// CHECK-LABEL: memref.global "private" @global_seed : memref<i64> = dense<0>
+// CHECK-LABEL: func @f() -> i64 {
+// CHECK:   %[[MEMREF:.*]] = memref.get_global @global_seed : memref<i64>
+// CHECK:   %[[SEED:.*]] = memref.load %[[MEMREF]][] : memref<i64>
+// CHECK:   %[[MULTIPLIER:.*]] = arith.constant 6364136223846793005 : i64
+// CHECK:   %[[INC:.*]] = arith.constant 1442695040888963407 : i64
+// CHECK:   %[[MUL:.*]] = arith.muli %[[SEED]], %[[MULTIPLIER]] : i64
+// CHECK:   %[[TEMP:.*]] = arith.addi %[[MUL]], %[[INC]] : i64
+// CHECK:   %[[CST127:.*]] = arith.constant 127 : i64
+// CHECK:   %[[NEXT_SEED:.*]] = arith.andi %[[TEMP]], %[[CST127]] : i64
+// CHECK:   memref.store %[[NEXT_SEED]], %[[MEMREF]][] : memref<i64>
+// CHECK:   return %[[NEXT_SEED]] : i64
+module {
+  func @f() -> i64 {
+    %seed = torch_c.get_next_seed : () -> i64
+    return %seed : i64
+  }
+}