mirror of https://github.com/llvm/torch-mlir
onnx.LSTM - bidirectional, layout attr (#3771)
- Support Bidirectional LSTM (utilising the forward LSTM layer with flipped Inputs and Outputs) - Support layout 1 - Support default cases for attr `clip` and `input_forget` - Support returning partial outputs (1-3) - fixes for alt_e2e_tests lstm tests (1,2,3)pull/3776/head
parent
58489faf7f
commit
7830c00ca2
|
@ -34,6 +34,7 @@ struct OpBinder {
|
|||
Location getLoc() { return op->getLoc(); }
|
||||
|
||||
int getNumOperands() { return op->getNumOperands(); }
|
||||
int getNumResults() { return op->getNumResults(); }
|
||||
|
||||
// Operand matches of different arities.
|
||||
ParseResult tensorOperand(Value &value0) {
|
||||
|
|
|
@ -661,8 +661,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
std::string direction;
|
||||
|
||||
ValueTensorType yTy, Y_hType, Y_cType;
|
||||
if (binder.tensorResultTypeAtIndex(yTy, 0) ||
|
||||
binder.tensorResultTypeAtIndex(Y_hType, 1) ||
|
||||
if (binder.tensorResultTypeAtIndex(yTy, 0) &&
|
||||
binder.tensorResultTypeAtIndex(Y_hType, 1) &&
|
||||
binder.tensorResultTypeAtIndex(Y_cType, 2)) {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"At least one outputs must be present");
|
||||
|
@ -686,51 +686,110 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
|
||||
auto xTy = cast<ValueTensorType>(X.getType());
|
||||
auto wTy = cast<ValueTensorType>(W.getType());
|
||||
Value B;
|
||||
if (binder.tensorOperandAtIndex(B, 3)) {
|
||||
B = b.create<AtenZerosOp>(W.getType(), W);
|
||||
}
|
||||
|
||||
// TODO: add defaults for activation_alpha acticvation_beta attributes
|
||||
|
||||
llvm::SmallVector<std::string> activationsList;
|
||||
if (binder.stringArrayAttr(activationsList, "activations"))
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Missing required attribute; activations");
|
||||
|
||||
LstmActivations activations;
|
||||
activations.f = "Sigmoid";
|
||||
activations.g = "Tanh";
|
||||
activations.h = "Tanh";
|
||||
if (activationsList.size() == 3) {
|
||||
activations.f = activationsList[0];
|
||||
activations.g = activationsList[1];
|
||||
activations.h = activationsList[2];
|
||||
} else if (activationsList.size() != 0) {
|
||||
if (!binder.customOpNameStringAttr(direction, "direction", "forward") &&
|
||||
direction != "forward" && direction != "bidirectional")
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "activations must be empty have 3 elements, but " +
|
||||
binder.op, "Unsupported direction attribute value. "
|
||||
"Only 'forward' / 'bidrectional' are supported but '" +
|
||||
direction + "' is provided.");
|
||||
int64_t num_directions = 1 + (direction == "bidirectional");
|
||||
bool isBidirectional = direction == "bidirectional";
|
||||
// There can be backward activations too
|
||||
// if backward -> look for 6 atcivations (what happens when only three?)
|
||||
|
||||
int64_t num_activations = activationsList.size();
|
||||
if (num_activations != 0 && num_activations != 3 && num_activations != 6) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "activations must either be empty (default), have 3 elements"
|
||||
" (forward) or, have 6 elements (bidirectional), but " +
|
||||
std::to_string(activationsList.size()) +
|
||||
" are provided.");
|
||||
}
|
||||
// TODO : Add checks, defaults and fails for inputs - sequence_lens, P and
|
||||
// attrs- clip, input_forget, layout
|
||||
|
||||
if (!binder.customOpNameStringAttr(direction, "direction", "forward") &&
|
||||
direction != "forward")
|
||||
Value B;
|
||||
if (binder.tensorOperandAtIndex(B, 3)) {
|
||||
Value none = b.create<ConstantNoneOp>();
|
||||
Value cstHiddenx8 = b.create<ConstantIntOp>(
|
||||
b.getType<IntType>(), b.getI64IntegerAttr(8 * hidden_size));
|
||||
Value cstNumDir = b.create<ConstantIntOp>(
|
||||
b.getType<IntType>(), b.getI64IntegerAttr(num_directions));
|
||||
auto BType = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{num_directions, 8 * hidden_size},
|
||||
cast<ValueTensorType>(W.getType()).getDtype());
|
||||
Value zerosShapeList = b.create<PrimListConstructOp>(
|
||||
b.getType<ListType>(b.getType<IntType>()),
|
||||
SmallVector<Value>{cstNumDir, cstHiddenx8});
|
||||
B = b.create<AtenZerosOp>(BType, zerosShapeList, none, none, none, none);
|
||||
}
|
||||
|
||||
LstmActivations activations, activationsRev;
|
||||
// Default case (both forward and reverse)
|
||||
activations.f = "Sigmoid";
|
||||
activations.g = "Tanh";
|
||||
activations.h = "Tanh";
|
||||
activationsRev.f = "Sigmoid";
|
||||
activationsRev.g = "Tanh";
|
||||
activationsRev.h = "Tanh";
|
||||
|
||||
// forward only (also to be added for bidirectional case)
|
||||
if (num_activations >= 3) {
|
||||
activations.f = activationsList[0];
|
||||
activations.g = activationsList[1];
|
||||
activations.h = activationsList[2];
|
||||
}
|
||||
|
||||
// bidirectional
|
||||
if (num_activations == 6) {
|
||||
activationsRev.f = activationsList[3];
|
||||
activationsRev.g = activationsList[4];
|
||||
activationsRev.h = activationsList[5];
|
||||
}
|
||||
|
||||
float clip;
|
||||
if (!binder.f32FloatAttr(clip, "clip", 0.0) && clip != 0.0)
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"Unsupported direction attribute value. "
|
||||
"Only 'forward' is supported but '" +
|
||||
direction + "' is provided.");
|
||||
int64_t num_directions = 1 + (direction == "bidirectional");
|
||||
"clip attribute not supported");
|
||||
|
||||
int64_t input_forget;
|
||||
if (!binder.s64IntegerAttr(input_forget, "input_forget", 0) &&
|
||||
input_forget != 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "only input_forget = 0 supported. Got input_forgt = " +
|
||||
std::to_string(input_forget));
|
||||
|
||||
int64_t layout;
|
||||
if (!binder.s64IntegerAttr(layout, "layout", 0) && layout != 0 && layout != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "invalid value of layout attribute, expecting 0 / 1 got " +
|
||||
std::to_string(layout));
|
||||
|
||||
auto XShape = xTy.getSizes();
|
||||
int64_t batch_size = XShape[1];
|
||||
int64_t seq_len, batch_size;
|
||||
if (layout == 0) {
|
||||
seq_len = XShape[0];
|
||||
batch_size = XShape[1];
|
||||
} else {
|
||||
seq_len = XShape[1];
|
||||
batch_size = XShape[0];
|
||||
}
|
||||
|
||||
int64_t input_size = XShape[2];
|
||||
if (num_directions != wTy.getSizes()[0])
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "num_directions (" + std::to_string(num_directions) +
|
||||
") does not match the first dimension of wTy (" +
|
||||
std::to_string(wTy.getSizes()[0]) + ")");
|
||||
if (num_directions != 1)
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "num_directions (" + std::to_string(num_directions) +
|
||||
") is not equal to 1");
|
||||
|
||||
if (4 * hidden_size != wTy.getSizes()[1])
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "4 times hidden_size (" + std::to_string(4 * hidden_size) +
|
||||
|
@ -746,6 +805,13 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
Value R_forward = getDirection(b, 0, R);
|
||||
Value B_forward = getDirection(b, 0, B);
|
||||
|
||||
Value W_reverse, R_reverse, B_reverse;
|
||||
if (isBidirectional) {
|
||||
W_reverse = getDirection(b, 1, W);
|
||||
R_reverse = getDirection(b, 1, R);
|
||||
B_reverse = getDirection(b, 1, B);
|
||||
}
|
||||
|
||||
auto hTy = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
|
||||
xTy.getDtype());
|
||||
|
@ -770,29 +836,44 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
|
||||
Value initial_h;
|
||||
if (binder.tensorOperandAtIndex(initial_h, 5)) {
|
||||
// default created for layout 0
|
||||
initial_h =
|
||||
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
||||
} else {
|
||||
if (layout == 1)
|
||||
initial_h = StaticTranspose(b, initial_h, 0, 1);
|
||||
}
|
||||
|
||||
Value initial_c;
|
||||
if (binder.tensorOperandAtIndex(initial_c, 6)) {
|
||||
// default created for layout 0
|
||||
initial_c =
|
||||
b.create<AtenZerosOp>(hTy, hShape, cstDtype, cstNone, cstNone, cstNone);
|
||||
} else {
|
||||
if (layout == 1)
|
||||
initial_c = StaticTranspose(b, initial_c, 0, 1);
|
||||
}
|
||||
|
||||
// convert X from layout 1 to layout 0
|
||||
if (layout == 1)
|
||||
X = StaticTranspose(b, X, 0, 1);
|
||||
|
||||
// X, initial_h, initial_c are now in layout 0
|
||||
|
||||
Value initial_h_forward = getDirection(b, 0, initial_h);
|
||||
Value initial_c_forward = getDirection(b, 0, initial_c);
|
||||
|
||||
if (num_directions != 1) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "Unsupported num_directions. Only 1 is supported but " +
|
||||
std::to_string(num_directions) + " is provided.");
|
||||
// TODO: support bidirectional LSTM by doing both directions and replacing
|
||||
// Unsqueeze with Stack
|
||||
Value initial_h_reverse, initial_c_reverse;
|
||||
if (isBidirectional) {
|
||||
initial_h_reverse = getDirection(b, 1, initial_h);
|
||||
initial_c_reverse = getDirection(b, 1, initial_c);
|
||||
}
|
||||
// Everything hereon is for the forward direction, with the direction
|
||||
// dimention squeezed out.
|
||||
|
||||
LstmWeights weights; // weights and biases
|
||||
// Everything hereon is for the forward direction (unless in bidirectional if
|
||||
// block), with the direction dimention squeezed out and all inputs in layout
|
||||
// 0 format
|
||||
|
||||
LstmWeights weights, weightsRev; // weights and biases
|
||||
|
||||
auto intConst = [&](int64_t val) {
|
||||
return b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(val));
|
||||
|
@ -804,6 +885,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
Value recurrentWeightsEndIdx = intConst(8 * hidden_size);
|
||||
auto biasType = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{hidden_size * 4}, wTy.getDtype());
|
||||
// forward
|
||||
Value Wb = b.create<AtenSliceTensorOp>(biasType,
|
||||
/*input=*/B_forward,
|
||||
/*dim=*/cstZero,
|
||||
|
@ -816,6 +898,22 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
/*start=*/recurrentWeightsStartIdx,
|
||||
/*end=*/recurrentWeightsEndIdx,
|
||||
/*step=*/cstOne);
|
||||
Value Wb_reverse, Rb_reverse;
|
||||
if (isBidirectional) {
|
||||
// reverse
|
||||
Wb_reverse = b.create<AtenSliceTensorOp>(biasType,
|
||||
/*input=*/B_reverse,
|
||||
/*dim=*/cstZero,
|
||||
/*start=*/cstZero,
|
||||
/*end=*/inputWeightsEndIdx,
|
||||
/*step=*/cstOne);
|
||||
Rb_reverse = b.create<AtenSliceTensorOp>(biasType,
|
||||
/*input=*/B_reverse,
|
||||
/*dim=*/cstZero,
|
||||
/*start=*/recurrentWeightsStartIdx,
|
||||
/*end=*/recurrentWeightsEndIdx,
|
||||
/*step=*/cstOne);
|
||||
}
|
||||
|
||||
// gate splitting
|
||||
auto gateBiasType = b.getType<ValueTensorType>(
|
||||
|
@ -833,61 +931,164 @@ LogicalResult OnnxLstmExpander(OpBinder binder,
|
|||
Value forgetGateWeightsEndIdx = intConst(3 * hidden_size);
|
||||
Value cellGateWeightsEndIdx = intConst(4 * hidden_size);
|
||||
|
||||
auto sliceIOFC = [&](std::function<Value(Value, Value)> slicerFunction) {
|
||||
auto sliceIOFC = [&](std::function<Value(Value, Value, Value)> slicerFunction,
|
||||
Value WoB) {
|
||||
// slice into 4 components and return tuple
|
||||
return std::make_tuple(
|
||||
slicerFunction(cstZero, inputGateWeightsEndIdx),
|
||||
slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx),
|
||||
slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx),
|
||||
slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx));
|
||||
slicerFunction(cstZero, inputGateWeightsEndIdx, WoB),
|
||||
slicerFunction(inputGateWeightsEndIdx, outputGateWeightsEndIdx, WoB),
|
||||
slicerFunction(outputGateWeightsEndIdx, forgetGateWeightsEndIdx, WoB),
|
||||
slicerFunction(forgetGateWeightsEndIdx, cellGateWeightsEndIdx, WoB));
|
||||
};
|
||||
|
||||
auto sliceGateBias = [&](Value startIdx, Value endIdx) {
|
||||
return b.create<AtenSliceTensorOp>(gateBiasType, Wb, cstZero, startIdx,
|
||||
auto sliceGateBias = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||
return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
|
||||
endIdx, cstOne);
|
||||
};
|
||||
std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) =
|
||||
sliceIOFC(sliceGateBias);
|
||||
sliceIOFC(sliceGateBias, Wb);
|
||||
|
||||
auto sliceGateBiasR = [&](Value startIdx, Value endIdx) {
|
||||
return b.create<AtenSliceTensorOp>(gateBiasType, Rb, cstZero, startIdx,
|
||||
if (isBidirectional)
|
||||
std::tie(weightsRev.Wb_i, weightsRev.Wb_o, weightsRev.Wb_f,
|
||||
weightsRev.Wb_c) = sliceIOFC(sliceGateBias, Wb_reverse);
|
||||
|
||||
auto sliceGateBiasR = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||
return b.create<AtenSliceTensorOp>(gateBiasType, WoB, cstZero, startIdx,
|
||||
endIdx, cstOne);
|
||||
};
|
||||
std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) =
|
||||
sliceIOFC(sliceGateBiasR);
|
||||
sliceIOFC(sliceGateBiasR, Rb);
|
||||
|
||||
auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx) {
|
||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeIH, W_forward, cstZero,
|
||||
if (isBidirectional)
|
||||
std::tie(weightsRev.Rb_i, weightsRev.Rb_o, weightsRev.Rb_f,
|
||||
weightsRev.Rb_c) = sliceIOFC(sliceGateBiasR, Rb_reverse);
|
||||
|
||||
auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeIH, WoB, cstZero,
|
||||
startIdx, endIdx, cstOne);
|
||||
};
|
||||
std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) =
|
||||
sliceIOFC(sliceGateWeightsIH);
|
||||
sliceIOFC(sliceGateWeightsIH, W_forward);
|
||||
|
||||
auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx) {
|
||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeHH, R_forward, cstZero,
|
||||
if (isBidirectional)
|
||||
std::tie(weightsRev.W_i, weightsRev.W_o, weightsRev.W_f, weightsRev.W_c) =
|
||||
sliceIOFC(sliceGateWeightsIH, W_reverse);
|
||||
|
||||
auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx, Value WoB) {
|
||||
return b.create<AtenSliceTensorOp>(gateWeightsTypeHH, WoB, cstZero,
|
||||
startIdx, endIdx, cstOne);
|
||||
};
|
||||
|
||||
std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) =
|
||||
sliceIOFC(sliceGateWeightsHH);
|
||||
sliceIOFC(sliceGateWeightsHH, R_forward);
|
||||
|
||||
if (isBidirectional)
|
||||
std::tie(weightsRev.R_i, weightsRev.R_o, weightsRev.R_f, weightsRev.R_c) =
|
||||
sliceIOFC(sliceGateWeightsHH, R_reverse);
|
||||
|
||||
LstmLayerOutput lstmLayerOutput = lstm_layer(
|
||||
b, X, initial_h_forward, initial_c_forward, weights, activations);
|
||||
|
||||
auto Y_h_Y_c_unsqueezed_type = b.getType<ValueTensorType>(
|
||||
Value Y_h_result, Y_c_result, Y_result;
|
||||
|
||||
// if forward (unidirectional) unsqueeze and output
|
||||
auto YallDtype =
|
||||
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype();
|
||||
auto Y_h_Y_c_uni_type = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{1, batch_size, hidden_size}, YallDtype);
|
||||
auto Y_uni_type = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{seq_len, 1, batch_size, hidden_size},
|
||||
YallDtype);
|
||||
auto Y_h_Y_c_res_type = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
|
||||
cast<ValueTensorType>(lstmLayerOutput.Y_h.getType()).getDtype());
|
||||
Value Y_h_unsqueezed = b.create<AtenUnsqueezeOp>(
|
||||
Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_h, cstZero);
|
||||
Value Y_c_unsqueezed = b.create<AtenUnsqueezeOp>(
|
||||
Y_h_Y_c_unsqueezed_type, lstmLayerOutput.Y_c, cstZero);
|
||||
YallDtype);
|
||||
auto Y_res_type = b.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{seq_len, num_directions, batch_size,
|
||||
hidden_size},
|
||||
YallDtype);
|
||||
|
||||
Value Y_h_forward =
|
||||
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero);
|
||||
|
||||
Value Y_c_forward =
|
||||
b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type, lstmLayerOutput.Y_c, cstZero);
|
||||
|
||||
// unsqueeze num_directions dim1 of Y
|
||||
// to create the onnx.LSTM output shape [seq_length, num_directions,
|
||||
// batch_size, hidden_size]
|
||||
Value Y_unsqueezed =
|
||||
b.create<AtenUnsqueezeOp>(yTy, lstmLayerOutput.Y, cstOne);
|
||||
Value Y_forward =
|
||||
b.create<AtenUnsqueezeOp>(Y_uni_type, lstmLayerOutput.Y, cstOne);
|
||||
|
||||
rewriter.replaceOp(binder.op, mlir::ValueRange{Y_unsqueezed, Y_h_unsqueezed,
|
||||
Y_c_unsqueezed});
|
||||
Y_result = Y_forward;
|
||||
Y_h_result = Y_h_forward;
|
||||
Y_c_result = Y_c_forward;
|
||||
|
||||
// add bidrectional reverse layer
|
||||
// this is just flip X, lstm layer, flip results, stack
|
||||
// flip X
|
||||
Value dim0, X_reverse, Y_h_reverse, Y_c_reverse, Y_reverse_unflipped,
|
||||
Y_reverse, Y_output_list, Y_h_output_list, Y_c_output_list;
|
||||
LstmLayerOutput revLstmLayerOutput;
|
||||
if (isBidirectional) {
|
||||
dim0 = b.create<PrimListConstructOp>(b.getType<ListType>(intType),
|
||||
SmallVector<Value>{cstZero});
|
||||
X_reverse = b.create<AtenFlipOp>(xTy, X, dim0); // flip along seq_len dim
|
||||
revLstmLayerOutput =
|
||||
lstm_layer(b, X_reverse, initial_h_reverse, initial_c_reverse,
|
||||
weightsRev, activationsRev);
|
||||
|
||||
// unsqueeze Y_rev, Y_h_rev, Y_c_rev
|
||||
Y_h_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
|
||||
revLstmLayerOutput.Y_h, cstZero);
|
||||
Y_c_reverse = b.create<AtenUnsqueezeOp>(Y_h_Y_c_uni_type,
|
||||
revLstmLayerOutput.Y_c, cstZero);
|
||||
Y_reverse_unflipped =
|
||||
b.create<AtenUnsqueezeOp>(Y_uni_type, revLstmLayerOutput.Y, cstOne);
|
||||
|
||||
// flip Y_rev on dim 0 [seq_len]
|
||||
Y_reverse = b.create<AtenFlipOp>(Y_uni_type, Y_reverse_unflipped, dim0);
|
||||
|
||||
// Concat forward and reverse results on dim 1
|
||||
Y_output_list =
|
||||
b.create<PrimListConstructOp>(b.getType<ListType>(Y_uni_type),
|
||||
SmallVector<Value>{Y_forward, Y_reverse});
|
||||
Y_result = b.create<AtenCatOp>(Y_res_type, Y_output_list, cstOne);
|
||||
|
||||
// Concat forward and reverse results on dim 0
|
||||
Y_h_output_list = b.create<PrimListConstructOp>(
|
||||
b.getType<ListType>(Y_h_Y_c_uni_type),
|
||||
SmallVector<Value>{Y_h_forward, Y_h_reverse});
|
||||
Y_h_result =
|
||||
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_h_output_list, cstZero);
|
||||
|
||||
Y_c_output_list = b.create<PrimListConstructOp>(
|
||||
b.getType<ListType>(Y_h_Y_c_uni_type),
|
||||
SmallVector<Value>{Y_c_forward, Y_c_reverse});
|
||||
Y_c_result =
|
||||
b.create<AtenCatOp>(Y_h_Y_c_res_type, Y_c_output_list, cstZero);
|
||||
}
|
||||
|
||||
if (layout == 1) {
|
||||
// Update Y, Y_h, Y_c results to layout 1
|
||||
Y_result = StaticTranspose(b, Y_result, 1, 2);
|
||||
Y_result = StaticTranspose(b, Y_result, 0, 1);
|
||||
Y_h_result = StaticTranspose(b, Y_h_result, 0, 1);
|
||||
Y_c_result = StaticTranspose(b, Y_c_result, 0, 1);
|
||||
}
|
||||
|
||||
// Only add outputs specified in onnx output node
|
||||
SmallVector<Value> actualOutputs = {Y_result, Y_h_result, Y_c_result},
|
||||
outputs;
|
||||
ValueTensorType resTy;
|
||||
for (int i = 0; i < binder.getNumResults(); ++i) {
|
||||
if (!binder.tensorResultTypeAtIndex(resTy, i) && !resTy) {
|
||||
outputs.push_back(cstNone);
|
||||
} else {
|
||||
outputs.push_back(actualOutputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOp(binder.op, outputs);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -16,10 +16,71 @@
|
|||
// CHECK-DAG: torch.prim.Loop.condition
|
||||
// CHECK-DAG: }
|
||||
// CHECK: }
|
||||
module {
|
||||
|
||||
func.func @test_lstm_basic(%arg0: !torch.vtensor<[15,2,4],f32>, %arg1: !torch.vtensor<[1,12,4],f32>, %arg2: !torch.vtensor<[1,12,3],f32>, %arg3: !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
||||
%none = torch.constant.none
|
||||
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.hidden_size = 3 : si64} : (!torch.vtensor<[15,2,4],f32>, !torch.vtensor<[1,12,4],f32>, !torch.vtensor<[1,12,3],f32>, !torch.vtensor<[1,24],f32>) -> (!torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>)
|
||||
return %0#0, %0#1, %0#2 : !torch.vtensor<[15,1,2,3],f32>, !torch.vtensor<[1,2,3],f32>, !torch.vtensor<[1,2,3],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_lstm_bidirectional_with_initial_bias(
|
||||
// CHECK-SAME: %[[X:.*]]: !torch.vtensor<[32,32,192],f32>,
|
||||
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[2,192,192],f32>,
|
||||
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[2,192,48],f32>,
|
||||
// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[2,384],f32>)
|
||||
// CHECK: %[[FORWARD_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP_FWD:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y_FWD:.*]], %[[INITIAL_H_FWD:.*]], %[[INITIAL_C_FWD:.*]]) {
|
||||
// CHECK: ^bb0(%[[FORWARD_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_FWD:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_FWD:.*]]: !torch.vtensor<[32,48],f32>):
|
||||
// CHECK-DAG: torch.aten.select.int
|
||||
// CHECK-DAG: torch.aten.linear
|
||||
// CHECK-DAG: torch.aten.sigmoid
|
||||
// CHECK-DAG: torch.aten.tanh
|
||||
// CHECK-DAG: torch.prim.Loop.condition
|
||||
// CHECK: }
|
||||
// CHECK: torch.aten.flip
|
||||
// CHECK: %[[REVERSE_LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIPS_REV:.*]], %[[LOOP_COND_REV:.*]], init(%[[Y_REV:.*]], %[[INITIAL_H_REV:.*]], %[[INITIAL_C_REV:.*]]) {
|
||||
// CHECK: ^bb0(%[[REVERSE_LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV_REV:.*]]: !torch.vtensor<[32,32,48],f32>, %[[H_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>, %[[C_PREV_REV:.*]]: !torch.vtensor<[32,48],f32>):
|
||||
// CHECK-DAG: torch.aten.select.int
|
||||
// CHECK-DAG: torch.aten.linear
|
||||
// CHECK-DAG: torch.aten.sigmoid
|
||||
// CHECK-DAG: torch.aten.tanh
|
||||
// CHECK-DAG: torch.prim.Loop.condition
|
||||
// CHECK: }
|
||||
// CHECK: torch.aten.flip
|
||||
// CHECK: return %[[Y:.*]], %[[Y_H:.*]], %[[Y_C:.*]] : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
|
||||
// CHECK: }
|
||||
|
||||
func.func @test_lstm_bidirectional_with_initial_bias(%arg0: !torch.vtensor<[32,32,192],f32>, %arg1: !torch.vtensor<[2,192,192],f32>, %arg2: !torch.vtensor<[2,192,48],f32>, %arg3: !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>) attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
|
||||
%none = torch.constant.none
|
||||
%0:3 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2, %arg3) {torch.onnx.direction = "bidirectional", torch.onnx.hidden_size = 48 : si64, torch.onnx.layout = 0 : si64} : (!torch.vtensor<[32,32,192],f32>, !torch.vtensor<[2,192,192],f32>, !torch.vtensor<[2,192,48],f32>, !torch.vtensor<[2,384],f32>) -> (!torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>)
|
||||
return %0#0, %0#1, %0#2 : !torch.vtensor<[32,2,32,48],f32>, !torch.vtensor<[2,32,48],f32>, !torch.vtensor<[2,32,48],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_lstm_batchwise_two_outputs(
|
||||
// CHECK-SAME: %[[X_LAYOUT_1:.*]]: !torch.vtensor<[3,1,2],f32>,
|
||||
// CHECK-SAME: %[[W:.*]]: !torch.vtensor<[1,28,2],f32>,
|
||||
// CHECK-SAME: %[[R:.*]]: !torch.vtensor<[1,28,7],f32>)
|
||||
// CHECK: torch.aten.transpose.int
|
||||
// CHECK: %[[LOOP_RES:.*]]:3 = torch.prim.Loop %[[MAX_TRIP:.*]], %[[LOOP_COND_FWD:.*]], init(%[[Y:.*]], %[[INITIAL_H:.*]], %[[INITIAL_C:.*]]) {
|
||||
// CHECK: ^bb0(%[[LOOP_INDEX:.*]]: !torch.int, %[[Y_PREV:.*]]: !torch.vtensor<[1,3,7],f32>, %[[H_PREV:.*]]: !torch.vtensor<[3,7],f32>, %[[C_PREV:.*]]: !torch.vtensor<[3,7],f32>):
|
||||
// CHECK-DAG: torch.aten.select.int
|
||||
// CHECK-DAG: torch.aten.linear
|
||||
// CHECK-DAG: torch.aten.sigmoid
|
||||
// CHECK-DAG: torch.aten.tanh
|
||||
// CHECK-DAG: torch.prim.Loop.condition
|
||||
// CHECK: }
|
||||
// CHECK-DAG: torch.aten.transpose.int
|
||||
// CHECK-DAG: torch.aten.transpose.int
|
||||
// CHECK-DAG: torch.aten.transpose.int
|
||||
// CHECK-DAG: torch.aten.transpose.int
|
||||
// CHECK: return %[[Y:.*]], %[[Y_H:.*]] : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
|
||||
// CHECK: }
|
||||
|
||||
func.func @test_lstm_batchwise_two_outputs(%arg0: !torch.vtensor<[3,1,2],f32>, %arg1: !torch.vtensor<[1,28,2],f32>, %arg2: !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
%none = torch.constant.none
|
||||
%0:2 = torch.operator "onnx.LSTM"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 7 : si64, torch.onnx.layout = 1 : si64} : (!torch.vtensor<[3,1,2],f32>, !torch.vtensor<[1,28,2],f32>, !torch.vtensor<[1,28,7],f32>) -> (!torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>)
|
||||
return %0#0, %0#1 : !torch.vtensor<[3,1,1,7],f32>, !torch.vtensor<[3,1,7],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue