mirror of https://github.com/llvm/torch-mlir
Support stash_type attribute for onnx.LayerNormalization
parent
878f9929db
commit
75b208bcfd
|
@ -2543,7 +2543,34 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
binder.s64IntegerAttr(stashType, "stash_type", 1))
|
||||
return failure();
|
||||
|
||||
std::optional<int64_t> stashTypeIntTorch =
|
||||
onnxDtypeIntToTorchDtypeInt(stashType);
|
||||
if (!stashTypeIntTorch.has_value())
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented support for the given stash_type");
|
||||
FailureOr<Type> stashDtype = Torch::getTypeForScalarType(
|
||||
binder.op->getContext(),
|
||||
(torch_upstream::ScalarType)stashTypeIntTorch.value());
|
||||
if (failed(stashDtype))
|
||||
return failure();
|
||||
|
||||
// Convert dtype if stash_type is different from input dtype
|
||||
auto xType = cast<Torch::ValueTensorType>(x.getType());
|
||||
Value cstFalse =
|
||||
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||
if (*stashDtype != xType.getOptionalDtype()) {
|
||||
auto newXType =
|
||||
xType.getWithSizesAndDtype(xType.getOptionalSizes(), *stashDtype);
|
||||
Value dtypeValue = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(),
|
||||
rewriter.getI64IntegerAttr(stashTypeIntTorch.value()));
|
||||
x = rewriter.create<Torch::AtenToDtypeOp>(
|
||||
binder.getLoc(), newXType, x, /*dtype=*/dtypeValue,
|
||||
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
||||
/*memory_format=*/none);
|
||||
}
|
||||
|
||||
Value constEpsilon = rewriter.create<Torch::ConstantFloatOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::FloatType>(),
|
||||
rewriter.getF64FloatAttr(epsilon));
|
||||
|
@ -2566,33 +2593,43 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
|
||||
normalized);
|
||||
|
||||
int64_t numResults = binder.op->getNumResults();
|
||||
if (numResults == 1) {
|
||||
SmallVector<int64_t> reducedShape(rank, 1);
|
||||
for (int64_t i = 0; i < axis; i++)
|
||||
reducedShape[i] = xShape[i];
|
||||
auto reducedType = xType.getWithSizesAndDtype(
|
||||
reducedShape, xType.getOptionalDtype());
|
||||
Value y = rewriter
|
||||
.create<Torch::AtenNativeLayerNormOp>(
|
||||
auto reducedType =
|
||||
xType.getWithSizesAndDtype(reducedShape, *stashDtype);
|
||||
auto y = rewriter.create<Torch::AtenNativeLayerNormOp>(
|
||||
binder.getLoc(), yType, /*meanType=*/reducedType,
|
||||
/*invStdDevType=*/reducedType, x, normalized_shape,
|
||||
scale, b, constEpsilon)
|
||||
.getResult0();
|
||||
rewriter.replaceOp(binder.op, y);
|
||||
/*invStdDevType=*/reducedType, x, normalized_shape, scale, b,
|
||||
constEpsilon);
|
||||
|
||||
int64_t numResults = binder.op->getNumResults();
|
||||
if (numResults == 1) {
|
||||
rewriter.replaceOp(binder.op, y.getResult0());
|
||||
return success();
|
||||
}
|
||||
if (numResults == 3) {
|
||||
|
||||
Value meanOutput = y.getResult1();
|
||||
Value varOutput = y.getResult2();
|
||||
// Convert meanType and varType back if stash_dtype is different
|
||||
if (binder.tensorResultTypeAtIndex(meanType, 1) ||
|
||||
binder.tensorResultTypeAtIndex(invStdDevType, 2))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<Torch::AtenNativeLayerNormOp>(
|
||||
binder.op, yType, meanType, invStdDevType, x, normalized_shape,
|
||||
scale, b, constEpsilon);
|
||||
return success();
|
||||
if (*stashDtype != meanType.getOptionalDtype()) {
|
||||
Value constDtype = Torch::getDtypeIntValueForType(
|
||||
rewriter, binder.getLoc(), meanType.getDtype());
|
||||
meanOutput = rewriter.create<Torch::AtenToDtypeOp>(
|
||||
binder.getLoc(), meanType, meanOutput, /*dtype=*/constDtype,
|
||||
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
||||
/*memory_format=*/none);
|
||||
varOutput = rewriter.create<Torch::AtenToDtypeOp>(
|
||||
binder.getLoc(), invStdDevType, varOutput, /*dtype=*/constDtype,
|
||||
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
||||
/*memory_format=*/none);
|
||||
}
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Unimplemented: expected either 1 or 3 results");
|
||||
rewriter.replaceOp(binder.op, {y.getResult0(), meanOutput, varOutput});
|
||||
|
||||
return success();
|
||||
});
|
||||
patterns.onOp("LeakyRelu", 1,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
|
|
|
@ -6635,7 +6635,7 @@ class DecomposeAtenNativeLayerNormOp
|
|||
Location loc = op.getLoc();
|
||||
auto context = op.getContext();
|
||||
|
||||
auto inputTy = cast<BaseTensorType>(op.getInput().getType());
|
||||
auto inputTy = cast<ValueTensorType>(op.getInput().getType());
|
||||
if (!inputTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "input tensor should have known sizes.");
|
||||
|
@ -6690,6 +6690,18 @@ class DecomposeAtenNativeLayerNormOp
|
|||
loc, inputTy, inputRsqrtVar, op.getInput());
|
||||
Value inputNormalized = rewriter.create<AtenMulTensorOp>(
|
||||
loc, inputTy, inputZeroMean, inputRsqrtVarExpanded);
|
||||
// Convert resultType if dtype is different
|
||||
auto resultTensorType =
|
||||
dyn_cast<ValueTensorType>(op.getResult(0).getType());
|
||||
if (inputTy.getDtype() != resultTensorType.getDtype()) {
|
||||
Value dtypeValue = Torch::getDtypeIntValueForType(
|
||||
rewriter, loc, resultTensorType.getDtype());
|
||||
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
|
||||
inputNormalized = rewriter.create<Torch::AtenToDtypeOp>(
|
||||
loc, resultTensorType, inputNormalized,
|
||||
/*dtype=*/dtypeValue, /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
|
||||
/*memory_format=*/none);
|
||||
}
|
||||
Value out = rewriter.create<TensorStaticInfoCastOp>(
|
||||
loc, op.getResult(0).getType(), inputNormalized);
|
||||
|
||||
|
|
Loading…
Reference in New Issue