mirror of https://github.com/llvm/torch-mlir
[Stablehlo] support uint8 (#3367)
Support lowering unsigned integer type to stablehlo as discussed in https://github.com/llvm/torch-mlir/pull/2184. The things I do in this PR: 1. create `setupBackendTypeConversionForStablehlo()`, `createFuncBackendTypeConversionForStablehloPass` and `createFinalizingBackendTypeConversionForStablehloPass`. 2. remove `InferTypeOpInterface` from `torch_c.to_builtin_tensor`, because it's different result type between linalg backend and stablehlo backend: ``` // linalg backend func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> { %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xi8> %0 = tensor.empty() : tensor<3xf32> %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<3xi8>) outs(%0 : tensor<3xf32>) { ^bb0(%in: i8, %out: f32): %2 = arith.uitofp %in : i8 to f32 linalg.yield %2 : f32 } -> tensor<3xf32> return %1 : tensor<3xf32> } // stablehlo backend func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> { %c = torch_c.to_builtin_tensor %arg0 : (!torch.vtensor<[3], ui8> -> tensor<3xui8> %0 = stablehlo.convert %arg0 : (tensor<3xui8> -> tensor<3xf32> return %0 : tensor<3xf32> } ``` 3. fix stablehlo and linalg's conversionpull/3414/head
parent
56d21cba62
commit
50f7103098
|
@ -1 +1 @@
|
|||
Subproject commit c44d9af8d4879adccf1054cb61a53377ae5898cb
|
||||
Subproject commit 25d237f6273361bb29e8436349c7067ee559dca2
|
|
@ -25,9 +25,7 @@ class TorchConversion_Op<string mnemonic, list<Trait> traits = []>
|
|||
// Conversions to backend types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>
|
||||
]> {
|
||||
def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor"> {
|
||||
let summary = "Convert a `!torch.vtensor` to a `tensor`";
|
||||
let description = [{
|
||||
This op only operates on ValueTensorType, to avoid conflating conversions
|
||||
|
|
|
@ -25,6 +25,11 @@ void getBackendTypeConversionDependentDialects(DialectRegistry ®istry);
|
|||
/// boundary (which currently consist only of builtin types).
|
||||
void setupBackendTypeConversion(ConversionTarget &target,
|
||||
TypeConverter &typeConverter);
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
void setupBackendTypeConversionForStablehlo(ConversionTarget &target,
|
||||
TypeConverter &typeConverter);
|
||||
#endif
|
||||
} // namespace TorchConversion
|
||||
} // namespace torch
|
||||
} // namespace mlir
|
||||
|
|
|
@ -48,6 +48,13 @@ struct StablehloBackendPipelineOptions
|
|||
|
||||
void createTorchBackendToStablehloBackendPipeline(
|
||||
OpPassManager &pm, const StablehloBackendPipelineOptions &options);
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createFuncBackendTypeConversionForStablehloPass();
|
||||
|
||||
std::unique_ptr<InterfacePass<FunctionOpInterface>>
|
||||
createFinalizingBackendTypeConversionForStablehloPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createVerifyStablehloBackendContractPass();
|
||||
#endif
|
||||
|
|
|
@ -21,6 +21,17 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu
|
|||
}];
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
def FuncBackendTypeConversionForStablehlo : Pass<"torch-func-backend-type-conversion-for-stablehlo", "ModuleOp"> {
|
||||
let summary = "Convert functions to operate on builtin tensors for stablehlo backend";
|
||||
let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionForStablehloPass()";
|
||||
let description = [{
|
||||
Partial type conversion pass analogous in scope to the upstream
|
||||
`func-bufferize` pass. See details there.
|
||||
}];
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
|
||||
def FinalizingBackendTypeConversion
|
||||
: InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> {
|
||||
let summary = "Finalizes a partial conversion to builtin tensors";
|
||||
|
@ -32,6 +43,19 @@ def FinalizingBackendTypeConversion
|
|||
}];
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
def FinalizingBackendTypeConversionForStablehlo
|
||||
: InterfacePass<"torch-finalizing-backend-type-conversion-for-stablehlo", "mlir::FunctionOpInterface"> {
|
||||
let summary = "Finalizes a partial conversion to builtin tensors for stablehlo";
|
||||
let constructor =
|
||||
"mlir::torch::TorchConversion::createFinalizingBackendTypeConversionForStablehloPass()";
|
||||
let description = [{
|
||||
Analogous in scope to the upstream `finalizing-bufferize` pass.
|
||||
See details there.
|
||||
}];
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
|
||||
def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
|
||||
let summary = "Verifies conformity to the linalg-on-tensors backend contract";
|
||||
let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";
|
||||
|
|
|
@ -1197,6 +1197,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
||||
Value input = payloadArgs[0];
|
||||
Type inputElementType =
|
||||
cast<BaseTensorType>(atenToDtype.getSelf().getType()).getDtype();
|
||||
Type dtype =
|
||||
cast<RankedTensorType>(converter->convertType(atenToDtype.getType()))
|
||||
.getElementType();
|
||||
|
@ -1215,7 +1217,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
}
|
||||
resultElementType = *maybeResultElementType;
|
||||
Value result = convertScalarToDtype(b, loc, input, dtype,
|
||||
/*srcOriginalDtype=*/std::nullopt,
|
||||
/*srcOriginalDtype=*/inputElementType,
|
||||
/*dstOriginalDtype=*/resultElementType);
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -277,8 +277,8 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
|
||||
if (!inputType)
|
||||
|
||||
op.emitError("only Tensor types supported in StableHLO");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value input = adaptor.getA();
|
||||
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||
|
@ -290,14 +290,24 @@ public:
|
|||
for (int64_t i = 0; i < inputRank; i++)
|
||||
checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
|
||||
|
||||
// handle unsigned interger
|
||||
if (inputType.getElementType().isUnsignedInteger()) {
|
||||
input = rewriter.create<stablehlo::ConvertOp>(
|
||||
loc, input,
|
||||
rewriter.getIntegerType(
|
||||
inputType.getElementType().getIntOrFloatBitWidth()));
|
||||
}
|
||||
|
||||
Value constantZero =
|
||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
|
||||
SmallVector<Value> indices(inputRank, constantZero);
|
||||
Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
|
||||
Type resultType =
|
||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||
rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
|
||||
resultType, inputDtype));
|
||||
rewriter.replaceOp(
|
||||
op,
|
||||
convertScalarToDtype(rewriter, loc, result, resultType, inputDtype,
|
||||
/*srcOriginalDtype=*/inputType.getElementType()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -900,7 +900,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
|||
for (int64_t i = maxIndexRank; i < inputRank; ++i) {
|
||||
updateWindowDims.push_back(i);
|
||||
}
|
||||
llvm::outs() << "maxIndexRank: " << maxIndexRank << "\n";
|
||||
|
||||
auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
|
||||
rewriter.getContext(),
|
||||
/*updateWindowDims=*/updateWindowDims,
|
||||
|
|
|
@ -51,7 +51,8 @@ public:
|
|||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
TorchConversion::setupBackendTypeConversionForStablehlo(target,
|
||||
typeConverter);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
||||
|
|
|
@ -23,7 +23,18 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) {
|
|||
if (lhs.hasRank() != rhs.hasRank())
|
||||
return false;
|
||||
bool sameSize = lhs.hasRank() ? lhs.getShape().equals(rhs.getShape()) : true;
|
||||
bool sameElementType = lhs.getElementType() == rhs.getElementType();
|
||||
bool sameElementType = false;
|
||||
// Namely, it is worth mentioning that the backends can have different
|
||||
// expectations for signedness when converting from and to the builtin MLIR
|
||||
// types. Therefore, the verifier cannot expect the input and output types to
|
||||
// match in their signedness.
|
||||
if (isa<IntegerType>(lhs.getElementType()) &&
|
||||
isa<IntegerType>(rhs.getElementType())) {
|
||||
sameElementType = lhs.getElementType().getIntOrFloatBitWidth() ==
|
||||
rhs.getElementType().getIntOrFloatBitWidth();
|
||||
} else {
|
||||
sameElementType = lhs.getElementType() == rhs.getElementType();
|
||||
}
|
||||
return sameElementType && sameSize;
|
||||
}
|
||||
|
||||
|
@ -42,18 +53,6 @@ LogicalResult ToBuiltinTensorOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto resultType =
|
||||
cast<Torch::ValueTensorType>(operands[0].getType()).toBuiltinTensor();
|
||||
if (!resultType)
|
||||
return failure();
|
||||
inferredReturnTypes.push_back(resultType);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FromBuiltinTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -23,22 +23,22 @@ void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects(
|
|||
// Type conversion setup.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void
|
||||
setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
||||
TypeConverter &typeConverter) {
|
||||
using ValueTensorTypeConversionFn =
|
||||
std::function<std::optional<Type>(Torch::ValueTensorType)>;
|
||||
|
||||
static void setupValueTensorToBuiltinTensorConversion(
|
||||
ConversionTarget &target, TypeConverter &typeConverter,
|
||||
const ValueTensorTypeConversionFn &conversionFn) {
|
||||
target.addLegalOp<TorchConversion::ToBuiltinTensorOp,
|
||||
TorchConversion::FromBuiltinTensorOp>();
|
||||
typeConverter.addConversion(
|
||||
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
||||
return type.toBuiltinTensor();
|
||||
});
|
||||
typeConverter.addConversion(conversionFn);
|
||||
typeConverter.addTargetMaterialization([](OpBuilder &builder, TensorType type,
|
||||
ValueRange inputs,
|
||||
Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
if (!isa<Torch::BaseTensorType>(inputs[0].getType()))
|
||||
return {};
|
||||
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
|
||||
return builder.create<ToBuiltinTensorOp>(loc, type, inputs[0]);
|
||||
});
|
||||
auto sourceMaterialization = [](OpBuilder &builder,
|
||||
Torch::ValueTensorType type,
|
||||
|
@ -162,9 +162,34 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
|
|||
|
||||
void mlir::torch::TorchConversion::setupBackendTypeConversion(
|
||||
ConversionTarget &target, TypeConverter &typeConverter) {
|
||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter);
|
||||
auto valueTensorTypeConversion =
|
||||
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
||||
return type.toBuiltinTensor();
|
||||
};
|
||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
||||
valueTensorTypeConversion);
|
||||
setupTorchBoolToI1Conversion(target, typeConverter);
|
||||
setupTorchIntToI64Conversion(target, typeConverter);
|
||||
setupTorchFloatToF64Conversion(target, typeConverter);
|
||||
setupTorchGeneratorToI64Conversion(target, typeConverter);
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
|
||||
ConversionTarget &target, TypeConverter &typeConverter) {
|
||||
auto valueTensorTypeConversion =
|
||||
[](Torch::ValueTensorType type) -> std::optional<Type> {
|
||||
auto builtinType = type.toBuiltinTensor();
|
||||
if (type.getDtype().isUnsignedInteger()) {
|
||||
return builtinType.clone(type.getDtype());
|
||||
}
|
||||
return builtinType;
|
||||
};
|
||||
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
|
||||
valueTensorTypeConversion);
|
||||
setupTorchBoolToI1Conversion(target, typeConverter);
|
||||
setupTorchIntToI64Conversion(target, typeConverter);
|
||||
setupTorchFloatToF64Conversion(target, typeConverter);
|
||||
setupTorchGeneratorToI64Conversion(target, typeConverter);
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -26,25 +26,12 @@ using namespace mlir::torch::TorchConversion;
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct FuncBackendTypeConversionPass
|
||||
: public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
|
||||
using FuncBackendTypeConversionBase<
|
||||
FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase;
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TorchConversion::TorchConversionDialect>();
|
||||
}
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
|
||||
patterns, typeConverter);
|
||||
void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target) {
|
||||
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
|
||||
typeConverter);
|
||||
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
|
||||
typeConverter.isLegal(&op.getBody());
|
||||
|
@ -63,11 +50,60 @@ struct FuncBackendTypeConversionPass
|
|||
typeConverter) ||
|
||||
isLegalForReturnOpTypeConversionPattern(op, typeConverter);
|
||||
});
|
||||
}
|
||||
|
||||
struct FuncBackendTypeConversionPass
|
||||
: public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
|
||||
using FuncBackendTypeConversionBase<
|
||||
FuncBackendTypeConversionPass>::FuncBackendTypeConversionBase;
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TorchConversion::TorchConversionDialect>();
|
||||
}
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||
|
||||
populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target);
|
||||
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
struct FuncBackendTypeConversionForStablehloPass
|
||||
: public FuncBackendTypeConversionForStablehloBase<
|
||||
FuncBackendTypeConversionForStablehloPass> {
|
||||
using FuncBackendTypeConversionForStablehloBase<
|
||||
FuncBackendTypeConversionForStablehloPass>::
|
||||
FuncBackendTypeConversionForStablehloBase;
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TorchConversion::TorchConversionDialect>();
|
||||
}
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversionForStablehlo(target,
|
||||
typeConverter);
|
||||
|
||||
populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target);
|
||||
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
|
@ -75,6 +111,13 @@ mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() {
|
|||
return std::make_unique<FuncBackendTypeConversionPass>();
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
|
||||
createFuncBackendTypeConversionForStablehloPass() {
|
||||
return std::make_unique<FuncBackendTypeConversionForStablehloPass>();
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FinalizingBackendTypeConversionPass
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -170,9 +213,61 @@ struct FinalizingBackendTypeConversionPass
|
|||
stripTorchAttrs(func);
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
struct FinalizingBackendTypeConversionForStablehloPass
|
||||
: public FinalizingBackendTypeConversionForStablehloBase<
|
||||
FinalizingBackendTypeConversionForStablehloPass> {
|
||||
using FinalizingBackendTypeConversionForStablehloBase<
|
||||
FinalizingBackendTypeConversionForStablehloPass>::
|
||||
FinalizingBackendTypeConversionForStablehloBase;
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
TorchConversion::setupBackendTypeConversionForStablehlo(target,
|
||||
typeConverter);
|
||||
|
||||
// Mark materializations as illegal in this pass (since we are finalizing)
|
||||
// and add patterns that eliminate them.
|
||||
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
|
||||
FromI64Op, ToI64Op, FromF64Op, ToF64Op, I64ToGeneratorOp,
|
||||
GeneratorToI64Op>(target, patterns, typeConverter);
|
||||
|
||||
// If all result types are legal, and all block arguments are legal, then
|
||||
// all types in the program are legal.
|
||||
//
|
||||
// We also check that the operand types are legal to avoid creating invalid
|
||||
// IR. For example, this prevents the patterns from updating
|
||||
// the types of the operands to a return op without updating the enclosing
|
||||
// function.
|
||||
target.markUnknownOpDynamicallyLegal(
|
||||
[&](Operation *op) { return typeConverter.isLegal(op); });
|
||||
|
||||
if (failed(applyFullConversion(func, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
|
||||
// Drop attributes that are no longer used after conversion out of Torch.
|
||||
stripTorchAttrs(func);
|
||||
}
|
||||
};
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<InterfacePass<FunctionOpInterface>>
|
||||
mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
|
||||
return std::make_unique<FinalizingBackendTypeConversionPass>();
|
||||
}
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
|
||||
std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::
|
||||
TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() {
|
||||
return std::make_unique<FinalizingBackendTypeConversionForStablehloPass>();
|
||||
}
|
||||
#endif // TORCH_MLIR_ENABLE_STABLEHLO
|
||||
|
|
|
@ -148,10 +148,11 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
|
|||
|
||||
// Finish the type conversion from `torch` types to the types of the
|
||||
// StableHLO backend contract.
|
||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
pm.addPass(
|
||||
TorchConversion::createFuncBackendTypeConversionForStablehloPass());
|
||||
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<func::FuncOp>(
|
||||
TorchConversion::createFinalizingBackendTypeConversionPass());
|
||||
TorchConversion::createFinalizingBackendTypeConversionForStablehloPass());
|
||||
|
||||
// Verify that we have lowered to Stablehlo ops.
|
||||
pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
|
||||
|
|
|
@ -66,6 +66,7 @@ void mlir::torch::registerAllPasses() {
|
|||
mlir::stablehlo::registerStablehloLegalizeToLinalgPass();
|
||||
mlir::stablehlo::registerStablehloAggressiveSimplificationPass();
|
||||
mlir::stablehlo::registerStablehloRefineShapesPass();
|
||||
mlir::stablehlo::registerStablehloConvertToSignlessPass();
|
||||
#endif
|
||||
|
||||
#ifdef TORCH_MLIR_ENABLE_REFBACKEND
|
||||
|
|
|
@ -826,6 +826,8 @@ STABLEHLO_PASS_SET = {
|
|||
"SplitWithSizes_Module_basic",
|
||||
"TensorSplitSections_GetItemModule_basic",
|
||||
"TensorSplitSections_ListUnpackModule_basic",
|
||||
"EmptyModule_uint8",
|
||||
"TypeConversionUint8ToF32Module_basic",
|
||||
"AtenLinear1D_basic",
|
||||
"AtenLinear2D_basic",
|
||||
"AtenLinear3DBias_basic",
|
||||
|
|
|
@ -23,6 +23,7 @@ STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join(
|
|||
[
|
||||
"func.func(stablehlo-aggressive-simplification)",
|
||||
"stablehlo-legalize-to-linalg",
|
||||
"stablehlo-convert-to-signless",
|
||||
"canonicalize",
|
||||
]
|
||||
)
|
||||
|
|
|
@ -136,6 +136,26 @@ def TypeConversionI1ToF64Module_basic(module, tu: TestUtils):
|
|||
module.forward(tensor)
|
||||
|
||||
|
||||
class TypeConversionUint8ToF32Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args(
|
||||
[
|
||||
None,
|
||||
([3], torch.uint8, True),
|
||||
]
|
||||
)
|
||||
def forward(self, x):
|
||||
return x.to(torch.float)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: TypeConversionUint8ToF32Module())
|
||||
def TypeConversionUint8ToF32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([0, 1, 255]).to(torch.uint8))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue