mirror of https://github.com/llvm/torch-mlir
Adds result types to a prelu decomp (#3098)
This adds explicit result types instead of relying on shape/dtype computations. Solves a regression issue with IREE: #3092pull/3100/head
parent
6cbb2f7ae0
commit
40e762ca42
|
@ -2382,15 +2382,15 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
Value input = op.getSelf();
|
||||
Value weight = op.getWeight();
|
||||
auto resType = op.getType().cast<BaseTensorType>();
|
||||
auto baseType =
|
||||
ValueTensorType::getWithLeastStaticInformation(op.getContext());
|
||||
auto resType = op.getType().cast<ValueTensorType>();
|
||||
auto boolTensorType = rewriter.getType<ValueTensorType>(
|
||||
resType.getOptionalSizes(), rewriter.getI1Type());
|
||||
Value zero =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||
Value inputMulWeight =
|
||||
rewriter.create<AtenMulTensorOp>(loc, baseType, input, weight);
|
||||
rewriter.create<AtenMulTensorOp>(loc, resType, input, weight);
|
||||
Value lessThanZero =
|
||||
rewriter.create<AtenLtScalarOp>(loc, baseType, input, zero);
|
||||
rewriter.create<AtenLtScalarOp>(loc, boolTensorType, input, zero);
|
||||
Value preluOutput = rewriter.create<AtenWhereSelfOp>(
|
||||
loc, resType, lessThanZero, inputMulWeight, input);
|
||||
|
||||
|
|
Loading…
Reference in New Issue