[Stablehlo] support uint8 (#3367)

Support lowering unsigned integer type to stablehlo as discussed in
https://github.com/llvm/torch-mlir/pull/2184.

This PR makes the following changes:
1. create `setupBackendTypeConversionForStablehlo()`,
`createFuncBackendTypeConversionForStablehloPass` and
`createFinalizingBackendTypeConversionForStablehloPass`.
2. remove `InferTypeOpInterface` from `torch_c.to_builtin_tensor`, because
the op's result type now differs between the linalg backend and the
stablehlo backend (a C++ sketch of the two conversions follows this list):
```
// linalg backend
#map = affine_map<(d0) -> (d0)>
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3],ui8> -> tensor<3xi8>
    %0 = tensor.empty() : tensor<3xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%c : tensor<3xi8>) outs(%0 : tensor<3xf32>) {
    ^bb0(%in: i8, %out: f32):
      %2 = arith.uitofp %in : i8 to f32
      linalg.yield %2 : f32
    } -> tensor<3xf32>
    return %1 : tensor<3xf32>
}
// stablehlo backend
func.func @forward(%arg0: !torch.vtensor<[3],ui8>) -> tensor<3xf32> {
    %c = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3],ui8> -> tensor<3xui8>
    %0 = stablehlo.convert %c : (tensor<3xui8>) -> tensor<3xf32>
    return %0 : tensor<3xf32>
}
```
3. fix the corresponding conversion patterns in the stablehlo and linalg backends.
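
To make the contrast in point 2 concrete, here is a sketch of the two `ValueTensorType` conversion callbacks, distilled from the `BackendTypeConversion.cpp` diff further down. The function names `convertForLinalg`/`convertForStablehlo` are illustrative, not real torch-mlir symbols, and the includes assume the torch-mlir tree:

```cpp
#include <optional>
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::torch;

// Default (linalg-on-tensors) conversion: signedness is erased, so
// !torch.vtensor<[3],ui8> becomes tensor<3xi8>.
static std::optional<Type> convertForLinalg(Torch::ValueTensorType type) {
  return type.toBuiltinTensor();
}

// Stablehlo conversion: the unsigned dtype is re-applied, so
// !torch.vtensor<[3],ui8> becomes tensor<3xui8>.
static std::optional<Type> convertForStablehlo(Torch::ValueTensorType type) {
  auto builtinType = type.toBuiltinTensor();
  if (type.getDtype().isUnsignedInteger())
    return builtinType.clone(type.getDtype());
  return builtinType;
}
```

Since a single `inferReturnTypes` implementation cannot produce both results, the interface has to go and the result type becomes explicit at each call site.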
Yuanqiang Liu, 2024-06-04 09:04:59 +08:00, committed by GitHub
parent 56d21cba62, commit 50f7103098
17 changed files with 245 additions and 54 deletions

externals/stablehlo (vendored submodule)

@@ -1 +1 @@
-Subproject commit c44d9af8d4879adccf1054cb61a53377ae5898cb
+Subproject commit 25d237f6273361bb29e8436349c7067ee559dca2

@@ -25,9 +25,7 @@ class TorchConversion_Op<string mnemonic, list<Trait> traits = []>
 // Conversions to backend types.
 //===----------------------------------------------------------------------===//
 
-def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor", [
-    DeclareOpInterfaceMethods<InferTypeOpInterface>
-  ]> {
+def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor"> {
   let summary = "Convert a `!torch.vtensor` to a `tensor`";
   let description = [{
     This op only operates on ValueTensorType, to avoid conflating conversions

@@ -25,6 +25,11 @@ void getBackendTypeConversionDependentDialects(DialectRegistry &registry);
 /// boundary (which currently consist only of builtin types).
 void setupBackendTypeConversion(ConversionTarget &target,
                                 TypeConverter &typeConverter);
+
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+void setupBackendTypeConversionForStablehlo(ConversionTarget &target,
+                                            TypeConverter &typeConverter);
+#endif
 } // namespace TorchConversion
 } // namespace torch
 } // namespace mlir

@@ -48,6 +48,13 @@ struct StablehloBackendPipelineOptions
 void createTorchBackendToStablehloBackendPipeline(
     OpPassManager &pm, const StablehloBackendPipelineOptions &options);
+
+std::unique_ptr<OperationPass<ModuleOp>>
+createFuncBackendTypeConversionForStablehloPass();
+
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createFinalizingBackendTypeConversionForStablehloPass();
+
 std::unique_ptr<OperationPass<ModuleOp>>
 createVerifyStablehloBackendContractPass();
 #endif

@@ -21,6 +21,17 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "ModuleOp"> {
   }];
 }
 
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+def FuncBackendTypeConversionForStablehlo : Pass<"torch-func-backend-type-conversion-for-stablehlo", "ModuleOp"> {
+  let summary = "Convert functions to operate on builtin tensors for stablehlo backend";
+  let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionForStablehloPass()";
+  let description = [{
+    Partial type conversion pass analogous in scope to the upstream
+    `func-bufferize` pass. See details there.
+  }];
+}
+#endif // TORCH_MLIR_ENABLE_STABLEHLO
+
 def FinalizingBackendTypeConversion
     : InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> {
   let summary = "Finalizes a partial conversion to builtin tensors";

@@ -32,6 +43,19 @@ def FinalizingBackendTypeConversion
   }];
 }
 
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+def FinalizingBackendTypeConversionForStablehlo
+    : InterfacePass<"torch-finalizing-backend-type-conversion-for-stablehlo", "mlir::FunctionOpInterface"> {
+  let summary = "Finalizes a partial conversion to builtin tensors for stablehlo";
+  let constructor =
+      "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionForStablehloPass()";
+  let description = [{
+    Analogous in scope to the upstream `finalizing-bufferize` pass.
+    See details there.
+  }];
+}
+#endif // TORCH_MLIR_ENABLE_STABLEHLO
+
 def VerifyLinalgOnTensorsBackendContract : Pass<"torch-verify-linalg-on-tensors-backend-contract", "ModuleOp"> {
   let summary = "Verifies conformity to the linalg-on-tensors backend contract";
   let constructor = "mlir::torch::TorchConversion::createVerifyLinalgOnTensorsBackendContractPass()";

@@ -1197,6 +1197,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
   }
   if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
     Value input = payloadArgs[0];
+    Type inputElementType =
+        cast<BaseTensorType>(atenToDtype.getSelf().getType()).getDtype();
     Type dtype =
         cast<RankedTensorType>(converter->convertType(atenToDtype.getType()))
             .getElementType();

@@ -1215,7 +1217,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     }
     resultElementType = *maybeResultElementType;
     Value result = convertScalarToDtype(b, loc, input, dtype,
-                                        /*srcOriginalDtype=*/std::nullopt,
+                                        /*srcOriginalDtype=*/inputElementType,
                                         /*dstOriginalDtype=*/resultElementType);
     return result;
   }

@@ -277,8 +277,8 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
     if (!inputType)
       op.emitError("only Tensor types supported in StableHLO");
 
     Location loc = op.getLoc();
     Value input = adaptor.getA();
     SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);

@@ -290,14 +290,24 @@ public:
     for (int64_t i = 0; i < inputRank; i++)
       checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne);
 
+    // handle unsigned integer
+    if (inputType.getElementType().isUnsignedInteger()) {
+      input = rewriter.create<stablehlo::ConvertOp>(
+          loc, input,
+          rewriter.getIntegerType(
+              inputType.getElementType().getIntOrFloatBitWidth()));
+    }
+
     Value constantZero =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
     SmallVector<Value> indices(inputRank, constantZero);
     Value result = rewriter.create<tensor::ExtractOp>(loc, input, indices);
     Type resultType =
         this->getTypeConverter()->convertType(op->getResult(0).getType());
-    rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result,
-                                                resultType, inputDtype));
+    rewriter.replaceOp(
+        op,
+        convertScalarToDtype(rewriter, loc, result, resultType, inputDtype,
+                             /*srcOriginalDtype=*/inputType.getElementType()));
     return success();
   }
 };

@@ -900,7 +900,7 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
   for (int64_t i = maxIndexRank; i < inputRank; ++i) {
     updateWindowDims.push_back(i);
   }
-  llvm::outs() << "maxIndexRank: " << maxIndexRank << "\n";
+
   auto scatterDimensionNumbers = stablehlo::ScatterDimensionNumbersAttr::get(
       rewriter.getContext(),
       /*updateWindowDims=*/updateWindowDims,

@@ -51,7 +51,8 @@ public:
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
-    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+    TorchConversion::setupBackendTypeConversionForStablehlo(target,
+                                                            typeConverter);
 
     RewritePatternSet patterns(context);

@@ -23,7 +23,18 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) {
   if (lhs.hasRank() != rhs.hasRank())
     return false;
   bool sameSize = lhs.hasRank() ? lhs.getShape().equals(rhs.getShape()) : true;
-  bool sameElementType = lhs.getElementType() == rhs.getElementType();
+  bool sameElementType = false;
+  // Note that the backends can have different expectations for signedness
+  // when converting from and to the builtin MLIR types. Therefore, the
+  // verifier cannot expect the input and output types to match in their
+  // signedness.
+  if (isa<IntegerType>(lhs.getElementType()) &&
+      isa<IntegerType>(rhs.getElementType())) {
+    sameElementType = lhs.getElementType().getIntOrFloatBitWidth() ==
+                      rhs.getElementType().getIntOrFloatBitWidth();
+  } else {
+    sameElementType = lhs.getElementType() == rhs.getElementType();
+  }
   return sameElementType && sameSize;
 }

@@ -42,18 +53,6 @@ LogicalResult ToBuiltinTensorOp::verify() {
   return success();
 }
 
-LogicalResult ToBuiltinTensorOp::inferReturnTypes(
-    MLIRContext *context, std::optional<Location> location, ValueRange operands,
-    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
-    SmallVectorImpl<Type> &inferredReturnTypes) {
-  auto resultType =
-      cast<Torch::ValueTensorType>(operands[0].getType()).toBuiltinTensor();
-  if (!resultType)
-    return failure();
-  inferredReturnTypes.push_back(resultType);
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // FromBuiltinTensorOp
 //===----------------------------------------------------------------------===//
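With `inferReturnTypes` removed, builders must state the op's result type explicitly, as the materialization callback in the next file does. A minimal sketch of the two spellings for a `!torch.vtensor<[3],ui8>` operand (the helper names and the fixed `[3]` shape are illustrative only, and the usual torch-mlir headers are assumed):

```cpp
// Sketch only: assumes an OpBuilder and Location from enclosing pattern code,
// plus the TorchConversion dialect headers.
static Value castUi8ForStablehlo(OpBuilder &b, Location loc, Value vtensor) {
  // Stablehlo keeps the unsigned element type: tensor<3xui8>.
  auto ui8 = b.getIntegerType(8, /*isSigned=*/false);
  return b.create<TorchConversion::ToBuiltinTensorOp>(
      loc, RankedTensorType::get({3}, ui8), vtensor);
}

static Value castUi8ForLinalg(OpBuilder &b, Location loc, Value vtensor) {
  // The linalg backend uses the signless form: tensor<3xi8>.
  return b.create<TorchConversion::ToBuiltinTensorOp>(
      loc, RankedTensorType::get({3}, b.getIntegerType(8)), vtensor);
}
```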

@@ -23,22 +23,22 @@ void mlir::torch::TorchConversion::getBackendTypeConversionDependentDialects(
 // Type conversion setup.
 //===----------------------------------------------------------------------===//
 
-static void
-setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
-                                          TypeConverter &typeConverter) {
+using ValueTensorTypeConversionFn =
+    std::function<std::optional<Type>(Torch::ValueTensorType)>;
+
+static void setupValueTensorToBuiltinTensorConversion(
+    ConversionTarget &target, TypeConverter &typeConverter,
+    const ValueTensorTypeConversionFn &conversionFn) {
   target.addLegalOp<TorchConversion::ToBuiltinTensorOp,
                     TorchConversion::FromBuiltinTensorOp>();
-  typeConverter.addConversion(
-      [](Torch::ValueTensorType type) -> std::optional<Type> {
-        return type.toBuiltinTensor();
-      });
+  typeConverter.addConversion(conversionFn);
   typeConverter.addTargetMaterialization([](OpBuilder &builder, TensorType type,
                                             ValueRange inputs,
                                             Location loc) -> Value {
     assert(inputs.size() == 1);
     if (!isa<Torch::BaseTensorType>(inputs[0].getType()))
       return {};
-    return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
+    return builder.create<ToBuiltinTensorOp>(loc, type, inputs[0]);
   });
   auto sourceMaterialization = [](OpBuilder &builder,
                                   Torch::ValueTensorType type,

@@ -162,9 +162,34 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
 
 void mlir::torch::TorchConversion::setupBackendTypeConversion(
     ConversionTarget &target, TypeConverter &typeConverter) {
-  setupValueTensorToBuiltinTensorConversion(target, typeConverter);
+  auto valueTensorTypeConversion =
+      [](Torch::ValueTensorType type) -> std::optional<Type> {
+    return type.toBuiltinTensor();
+  };
+  setupValueTensorToBuiltinTensorConversion(target, typeConverter,
+                                            valueTensorTypeConversion);
   setupTorchBoolToI1Conversion(target, typeConverter);
   setupTorchIntToI64Conversion(target, typeConverter);
   setupTorchFloatToF64Conversion(target, typeConverter);
   setupTorchGeneratorToI64Conversion(target, typeConverter);
 }
+
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  auto valueTensorTypeConversion =
+      [](Torch::ValueTensorType type) -> std::optional<Type> {
+    auto builtinType = type.toBuiltinTensor();
+    if (type.getDtype().isUnsignedInteger()) {
+      return builtinType.clone(type.getDtype());
+    }
+    return builtinType;
+  };
+  setupValueTensorToBuiltinTensorConversion(target, typeConverter,
+                                            valueTensorTypeConversion);
+  setupTorchBoolToI1Conversion(target, typeConverter);
+  setupTorchIntToI64Conversion(target, typeConverter);
+  setupTorchFloatToF64Conversion(target, typeConverter);
+  setupTorchGeneratorToI64Conversion(target, typeConverter);
+}
+#endif

@@ -26,6 +26,32 @@ using namespace mlir::torch::TorchConversion;
 //===----------------------------------------------------------------------===//
 
 namespace {
+
+void populateFuncBackendTypeConversionPatterns(TypeConverter &typeConverter,
+                                               RewritePatternSet &patterns,
+                                               ConversionTarget &target) {
+  populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                 typeConverter);
+  target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+    return typeConverter.isSignatureLegal(op.getFunctionType()) &&
+           typeConverter.isLegal(&op.getBody());
+  });
+  populateCallOpTypeConversionPattern(patterns, typeConverter);
+  target.addDynamicallyLegalOp<func::CallOp>(
+      [&](func::CallOp op) { return typeConverter.isLegal(op); });
+
+  populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
+  populateReturnOpTypeConversionPattern(patterns, typeConverter);
+  target.addLegalOp<ModuleOp>();
+
+  target.markUnknownOpDynamicallyLegal([&](Operation *op) {
+    return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
+           isLegalForBranchOpInterfaceTypeConversionPattern(op,
+                                                            typeConverter) ||
+           isLegalForReturnOpTypeConversionPattern(op, typeConverter);
+  });
+}
+
 struct FuncBackendTypeConversionPass
     : public FuncBackendTypeConversionBase<FuncBackendTypeConversionPass> {
   using FuncBackendTypeConversionBase<

@@ -43,31 +69,41 @@ struct FuncBackendTypeConversionPass
     typeConverter.addConversion([](Type type) { return type; });
     TorchConversion::setupBackendTypeConversion(target, typeConverter);
 
-    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, typeConverter);
-    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
-      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
-             typeConverter.isLegal(&op.getBody());
-    });
-    populateCallOpTypeConversionPattern(patterns, typeConverter);
-    target.addDynamicallyLegalOp<func::CallOp>(
-        [&](func::CallOp op) { return typeConverter.isLegal(op); });
-
-    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
-    populateReturnOpTypeConversionPattern(patterns, typeConverter);
-    target.addLegalOp<ModuleOp>();
-
-    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
-      return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
-             isLegalForBranchOpInterfaceTypeConversionPattern(op,
-                                                              typeConverter) ||
-             isLegalForReturnOpTypeConversionPattern(op, typeConverter);
-    });
+    populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target);
 
     if (failed(applyFullConversion(module, target, std::move(patterns))))
       signalPassFailure();
   }
 };
+
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+struct FuncBackendTypeConversionForStablehloPass
+    : public FuncBackendTypeConversionForStablehloBase<
+          FuncBackendTypeConversionForStablehloPass> {
+  using FuncBackendTypeConversionForStablehloBase<
+      FuncBackendTypeConversionForStablehloPass>::
+      FuncBackendTypeConversionForStablehloBase;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<TorchConversion::TorchConversionDialect>();
+  }
+  void runOnOperation() override {
+    auto module = getOperation();
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversionForStablehlo(target,
+                                                            typeConverter);
+
+    populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target);
+
+    if (failed(applyFullConversion(module, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+#endif // TORCH_MLIR_ENABLE_STABLEHLO
+
 } // namespace
 
 std::unique_ptr<OperationPass<ModuleOp>>
@@ -75,6 +111,13 @@ mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() {
   return std::make_unique<FuncBackendTypeConversionPass>();
 }
 
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
+    createFuncBackendTypeConversionForStablehloPass() {
+  return std::make_unique<FuncBackendTypeConversionForStablehloPass>();
+}
+#endif // TORCH_MLIR_ENABLE_STABLEHLO
+
 //===----------------------------------------------------------------------===//
 // FinalizingBackendTypeConversionPass
 //===----------------------------------------------------------------------===//

@@ -170,9 +213,61 @@ struct FinalizingBackendTypeConversionPass
     stripTorchAttrs(func);
   }
 };
+
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+struct FinalizingBackendTypeConversionForStablehloPass
+    : public FinalizingBackendTypeConversionForStablehloBase<
+          FinalizingBackendTypeConversionForStablehloPass> {
+  using FinalizingBackendTypeConversionForStablehloBase<
+      FinalizingBackendTypeConversionForStablehloPass>::
+      FinalizingBackendTypeConversionForStablehloBase;
+
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversionForStablehlo(target,
+                                                            typeConverter);
+
+    // Mark materializations as illegal in this pass (since we are finalizing)
+    // and add patterns that eliminate them.
+    setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
+                      FromI64Op, ToI64Op, FromF64Op, ToF64Op, I64ToGeneratorOp,
+                      GeneratorToI64Op>(target, patterns, typeConverter);
+
+    // If all result types are legal, and all block arguments are legal, then
+    // all types in the program are legal.
+    //
+    // We also check that the operand types are legal to avoid creating invalid
+    // IR. For example, this prevents the patterns from updating
+    // the types of the operands to a return op without updating the enclosing
+    // function.
+    target.markUnknownOpDynamicallyLegal(
+        [&](Operation *op) { return typeConverter.isLegal(op); });
+
+    if (failed(applyFullConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+
+    // Drop attributes that are no longer used after conversion out of Torch.
+    stripTorchAttrs(func);
+  }
+};
+#endif // TORCH_MLIR_ENABLE_STABLEHLO
+
 } // namespace
 
 std::unique_ptr<InterfacePass<FunctionOpInterface>>
 mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
   return std::make_unique<FinalizingBackendTypeConversionPass>();
 }
+
+#ifdef TORCH_MLIR_ENABLE_STABLEHLO
+std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::
+    TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() {
+  return std::make_unique<FinalizingBackendTypeConversionForStablehloPass>();
+}
+#endif // TORCH_MLIR_ENABLE_STABLEHLO

@@ -148,10 +148,11 @@ void TorchConversion::createTorchBackendToStablehloBackendPipeline(
 
   // Finish the type conversion from `torch` types to the types of the
   // StableHLO backend contract.
-  pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
+  pm.addPass(
+      TorchConversion::createFuncBackendTypeConversionForStablehloPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(
-      TorchConversion::createFinalizingBackendTypeConversionPass());
+      TorchConversion::createFinalizingBackendTypeConversionForStablehloPass());
 
   // Verify that we have lowered to Stablehlo ops.
   pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
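For anyone driving these passes programmatically rather than through the prebuilt pipeline, the new tail amounts to something like the following sketch (it assumes the usual torch-mlir includes and pass registration; the helper function name is illustrative):

```cpp
// Sketch: run just the stablehlo type-conversion tail on a module.
LogicalResult runStablehloTypeConversionTail(ModuleOp module) {
  PassManager pm(module->getContext());
  pm.addPass(TorchConversion::createFuncBackendTypeConversionForStablehloPass());
  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
  pm.addNestedPass<func::FuncOp>(
      TorchConversion::createFinalizingBackendTypeConversionForStablehloPass());
  // Check the result against the stablehlo backend contract.
  pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass());
  return pm.run(module);
}
```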

@@ -66,6 +66,7 @@ void mlir::torch::registerAllPasses() {
   mlir::stablehlo::registerStablehloLegalizeToLinalgPass();
   mlir::stablehlo::registerStablehloAggressiveSimplificationPass();
   mlir::stablehlo::registerStablehloRefineShapesPass();
+  mlir::stablehlo::registerStablehloConvertToSignlessPass();
 #endif
 
 #ifdef TORCH_MLIR_ENABLE_REFBACKEND

@@ -826,6 +826,8 @@ STABLEHLO_PASS_SET = {
     "SplitWithSizes_Module_basic",
     "TensorSplitSections_GetItemModule_basic",
     "TensorSplitSections_ListUnpackModule_basic",
+    "EmptyModule_uint8",
+    "TypeConversionUint8ToF32Module_basic",
     "AtenLinear1D_basic",
     "AtenLinear2D_basic",
     "AtenLinear3DBias_basic",

@@ -23,6 +23,7 @@ STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join(
     [
         "func.func(stablehlo-aggressive-simplification)",
         "stablehlo-legalize-to-linalg",
+        "stablehlo-convert-to-signless",
         "canonicalize",
     ]
 )

@@ -136,6 +136,26 @@ def TypeConversionI1ToF64Module_basic(module, tu: TestUtils):
     module.forward(tensor)
 
 
+class TypeConversionUint8ToF32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([3], torch.uint8, True),
+        ]
+    )
+    def forward(self, x):
+        return x.to(torch.float)
+
+
+@register_test_case(module_factory=lambda: TypeConversionUint8ToF32Module())
+def TypeConversionUint8ToF32Module_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([0, 1, 255]).to(torch.uint8))
+
+
 # ==============================================================================