[x265] [PATCH] SSIM based RDO for mode selection

ashok at multicorewareinc.com ashok at multicorewareinc.com
Wed Dec 28 16:11:02 CET 2016


# HG changeset patch
# User Ashok Kumar Mishra <ashok at multicorewareinc.com>
# Date 1482932522 -19800
#      Wed Dec 28 19:12:02 2016 +0530
# Node ID 146036b4049c7d5abae3bae83f77d573b67f167e
# Parent  af10eaeb36cd22c7ad20ed2dafeac6f8e388ed9d
SSIM based RDO for mode selection

diff -r af10eaeb36cd -r 146036b4049c source/common/cudata.cpp
--- a/source/common/cudata.cpp	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/common/cudata.cpp	Wed Dec 28 19:12:02 2016 +0530
@@ -224,6 +224,7 @@
         m_trCoeff[0] = dataPool.trCoeffMemBlock + instance * (cuSize * cuSize);
         m_trCoeff[1] = m_trCoeff[2] = 0;
         m_transformSkip[1] = m_transformSkip[2] = m_cbf[1] = m_cbf[2] = 0;
+        m_fAc_den[0] = m_fDc_den[0] = 0;
     }
     else
     {
@@ -267,6 +268,8 @@
         m_trCoeff[0] = dataPool.trCoeffMemBlock + instance * (sizeL + sizeC * 2);
         m_trCoeff[1] = m_trCoeff[0] + sizeL;
         m_trCoeff[2] = m_trCoeff[0] + sizeL + sizeC;
+        for (int i = 0; i < 3; i++)
+            m_fAc_den[i] = m_fDc_den[i] = 0;
     }
 }
 
@@ -327,6 +330,11 @@
     m_bFirstRowInSlice = ctu.m_bFirstRowInSlice;
     m_bLastRowInSlice = ctu.m_bLastRowInSlice;
     m_bLastCuInSlice = ctu.m_bLastCuInSlice;
+    for (int i = 0; i < 3; i++)
+    {
+        m_fAc_den[i] = ctu.m_fAc_den[i];
+        m_fDc_den[i] = ctu.m_fDc_den[i];
+    }
 
     X265_CHECK(m_numPartitions == cuGeom.numPartitions, "initSubCU() size mismatch\n");
 
diff -r af10eaeb36cd -r 146036b4049c source/common/cudata.h
--- a/source/common/cudata.h	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/common/cudata.h	Wed Dec 28 19:12:02 2016 +0530
@@ -218,6 +218,8 @@
     const CUData* m_cuAbove;          // pointer to above neighbor CTU
     const CUData* m_cuLeft;           // pointer to left neighbor CTU
     double m_meanQP;
+    uint64_t      m_fAc_den[3];
+    uint64_t      m_fDc_den[3];
 
     CUData();
 
diff -r af10eaeb36cd -r 146036b4049c source/common/framedata.h
--- a/source/common/framedata.h	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/common/framedata.h	Wed Dec 28 19:12:02 2016 +0530
@@ -55,6 +55,7 @@
     double      avgLumaDistortion;
     double      avgChromaDistortion;
     double      avgPsyEnergy;
+    double      avgSsimEnergy;
     double      avgResEnergy;
     double      percentIntraNxN;
     double      percentSkipCu[NUM_CU_DEPTH];
@@ -68,6 +69,7 @@
     uint64_t    lumaDistortion;
     uint64_t    chromaDistortion;
     uint64_t    psyEnergy;
+    int64_t     ssimEnergy;
     uint64_t    resEnergy;
     uint64_t    cntSkipCu[NUM_CU_DEPTH];
     uint64_t    cntMergeCu[NUM_CU_DEPTH];
diff -r af10eaeb36cd -r 146036b4049c source/common/param.cpp
--- a/source/common/param.cpp	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/common/param.cpp	Wed Dec 28 19:12:02 2016 +0530
@@ -202,6 +202,7 @@
     param->bEnableTemporalSubLayers = 0;
     param->bEnableRdRefine = 0;
     param->bMultiPassOptRPS = 0;
+    param->bSsimRd = 0;
 
     /* Rate control options */
     param->rc.vbvMaxBitrate = 0;
@@ -926,6 +927,16 @@
         OPT("opt-cu-delta-qp") p->bOptCUDeltaQP = atobool(value);
         OPT("multi-pass-opt-analysis") p->analysisMultiPassRefine = atobool(value);
         OPT("multi-pass-opt-distortion") p->analysisMultiPassDistortion = atobool(value);
+        OPT("ssim-rd")
+        {
+            int bval = atobool(value);
+            if (bError || bval)
+            {
+                bError = false;
+                p->psyRd = 0.0;
+                p->bSsimRd = atobool(value);
+            }
+        }
         else
             return X265_PARAM_BAD_NAME;
     }
diff -r af10eaeb36cd -r 146036b4049c source/common/quant.cpp
--- a/source/common/quant.cpp	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/common/quant.cpp	Wed Dec 28 19:12:02 2016 +0530
@@ -479,6 +479,82 @@
     }
 }
 
+/* Estimate an SSIM-weighted distortion between a source block and a candidate
+ * block, both (1 << log2TrSize) pixels wide and high.
+ *
+ * The squared error is split into a "DC" part, sampled at the top-left pixel
+ * of every 4x4 sub-block, and an "AC" part made of all remaining pixels.  Each
+ * part is scaled by denominator / numerator: the denominators are read from
+ * cu.m_fDc_den / cu.m_fAc_den for the given plane, the numerators are the
+ * corresponding energy terms computed here from fenc.
+ *
+ * cu         - CU supplying the normalization denominators and the QP
+ * fenc       - source pixels, fStride elements per row
+ * recon      - candidate pixels (callers pass a reconstruction or a
+ *              prediction), rstride elements per row
+ * log2TrSize - log2 of the block size; must be >= 2, because both numerators
+ *              are divided by the number of 4x4 sub-blocks
+ * ttype      - plane index into cu.m_fDc_den / cu.m_fAc_den
+ * absPartIdx - partition index used to look up cu.m_qp
+ *
+ * Returns the sum of the scaled DC and AC errors.
+ *
+ * All four sums are accumulated in a single pass over the block, so fenc and
+ * recon are each read only once. */
+uint64_t Quant::ssimDistortion(const CUData& cu, const pixel* fenc, uint32_t fStride, const pixel* recon, intptr_t rstride, uint32_t log2TrSize, TextType ttype, uint32_t absPartIdx)
+{
+    static const int ssim_c1 = (int)(.01 * .01 * PIXEL_MAX * PIXEL_MAX * 64 + .5); // 416
+    static const int ssim_c2 = (int)(.03 * .03 * PIXEL_MAX * PIXEL_MAX * 64 * 63 + .5); // 235963
+
+    const int trSize = 1 << log2TrSize;
+    const int numBlocks4x4 = (trSize >> 2) * (trSize >> 2);
+
+    uint64_t ssDc = 0;    // squared error at the DC sample positions
+    uint64_t ssBlock = 0; // squared error over the whole block
+    uint64_t dc_k = 0;    // source energy at the DC sample positions
+    uint64_t blk_k = 0;   // source energy over the whole block
+
+    // One pass gathers all four sums; the DC positions are the pixels whose
+    // row and column are both multiples of four.
+    for (int y = 0; y < trSize; y++)
+    {
+        const pixel* fencRow = fenc + y * fStride;
+        const pixel* reconRow = recon + y * rstride;
+        const bool dcRow = !(y & 3);
+
+        for (int x = 0; x < trSize; x++)
+        {
+            const uint32_t src = fencRow[x];
+            const int diff = (int)src - reconRow[x];
+            const uint64_t sqDiff = diff * diff;
+            const uint64_t sqSrc = src * src;
+
+            ssBlock += sqDiff;
+            blk_k += sqSrc;
+            if (dcRow && !(x & 3))
+            {
+                ssDc += sqDiff;
+                dc_k += sqSrc;
+            }
+        }
+    }
+
+    // Everything that is not DC counts as AC.
+    const uint64_t ssAc = ssBlock - ssDc;
+    const uint64_t ac_k = blk_k - dc_k;
+
+    // Numerators of the normalization factors, averaged per 4x4 sub-block.
+    uint64_t fDc_num = (2 * dc_k) + (trSize * trSize * ssim_c1);
+    fDc_num /= numBlocks4x4;
+
+    // The AC numerator carries an extra term scaled by a QP-dependent factor.
+    const double s = 1 + 0.005 * cu.m_qp[absPartIdx];
+    uint64_t fAc_num = ac_k + uint64_t(s * ac_k) + ssim_c2;
+    fAc_num /= numBlocks4x4;
+
+    return ((ssDc * cu.m_fDc_den[ttype]) / fDc_num) + ((ssAc * cu.m_fAc_den[ttype]) / fAc_num);
+}
+
 void Quant::invtransformNxN(const CUData& cu, int16_t* residual, uint32_t resiStride, const coeff_t* coeff,
                             uint32_t log2TrSize, TextType ttype, bool bIntra, bool useTransformSkip, uint32_t numSig)
 {
diff -r af10eaeb36cd -r 146036b4049c source/common/quant.h
--- a/source/common/quant.h	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/common/quant.h	Wed Dec 28 19:12:02 2016 +0530
@@ -111,6 +111,8 @@
 
     void invtransformNxN(const CUData& cu, int16_t* residual, uint32_t resiStride, const coeff_t* coeff,
                          uint32_t log2TrSize, TextType ttype, bool bIntra, bool useTransformSkip, uint32_t numSig);
+    uint64_t ssimDistortion(const CUData& cu, const pixel* fenc, uint32_t fStride, const pixel* recon, intptr_t rstride,
+                            uint32_t log2TrSize, TextType ttype, uint32_t absPartIdx);
 
     /* Pattern decision for context derivation process of significant_coeff_flag */
     static uint32_t calcPatternSigCtx(uint64_t sigCoeffGroupFlag64, uint32_t cgPosX, uint32_t cgPosY, uint32_t cgBlkPos, uint32_t trSizeCG)
diff -r af10eaeb36cd -r 146036b4049c source/encoder/analysis.cpp
--- a/source/encoder/analysis.cpp	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/encoder/analysis.cpp	Wed Dec 28 19:12:02 2016 +0530
@@ -76,6 +76,7 @@
     m_reuseRef = NULL;
     m_bHD = false;
 }
+
 bool Analysis::create(ThreadLocalData *tld)
 {
     m_tld = tld;
@@ -145,6 +146,9 @@
     ctu.m_meanQP = initialContext.m_meanQP;
     m_modeDepth[0].fencYuv.copyFromPicYuv(*m_frame->m_fencPic, ctu.m_cuAddr, 0);
 
+    if (m_param->bSsimRd)
+        calculateNormFactor(ctu, qp);
+
     uint32_t numPartition = ctu.m_numPartitions;
     if (m_param->analysisMultiPassRefine && m_param->rc.bStatRead)
     {
@@ -2910,3 +2914,65 @@
 
     return x265_clip3(m_param->rc.qpMin, m_param->rc.qpMax, (int)(qp + 0.5));
 }
+
+/* Compute the SSIM normalization denominators of one plane of a CTU and store
+ * them in ctu.m_fDc_den[ttype] / ctu.m_fAc_den[ttype].  src is read as a
+ * blockSize x blockSize square with a row stride of blockSize.  DC energy is
+ * sampled at the top-left pixel of every 4x4 sub-block, AC energy is the rest
+ * plus an extra term scaled by s = 1 + 0.005 * qp.  Both results are averaged
+ * per 4x4 sub-block, so blockSize must be at least 4. */
+void Analysis::normFactor(const pixel* src, uint32_t blockSize, CUData& ctu, int qp, TextType ttype)
+{
+    static const int ssim_c1 = (int)(.01 * .01 * PIXEL_MAX * PIXEL_MAX * 64 + .5); // 416
+    static const int ssim_c2 = (int)(.03 * .03 * PIXEL_MAX * PIXEL_MAX * 64 * 63 + .5); // 235963
+
+    const double s = 1 + 0.005 * qp;
+    const uint32_t numBlocks4x4 = (blockSize >> 2) * (blockSize >> 2);
+
+    uint64_t z_o = 0;   // energy at the DC sample positions
+    uint64_t z_all = 0; // energy of the whole block
+    for (uint32_t y = 0; y < blockSize; y++)
+    {
+        const pixel* row = src + y * blockSize;
+        const bool dcRow = !(y & 3);
+
+        for (uint32_t x = 0; x < blockSize; x++)
+        {
+            const uint32_t sample = row[x];
+            z_all += sample * sample;
+            if (dcRow && !(x & 3))
+                z_o += sample * sample;
+        }
+    }
+
+    // 2 * (Z(0)) pow(2) + N * C1
+    uint64_t fDc_den = (2 * z_o) + (blockSize * blockSize * ssim_c1);
+    fDc_den /= numBlocks4x4;
+
+    // Remove the DC part.  The scaled term is converted through uint64_t:
+    // s * z_k can exceed the range of int for a large block of wide samples.
+    const uint64_t z_k = z_all - z_o;
+    uint64_t fAc_den = z_k + uint64_t(s * z_k) + ssim_c2;
+    fAc_den /= numBlocks4x4;
+
+    ctu.m_fAc_den[ttype] = fAc_den;
+    ctu.m_fDc_den[ttype] = fDc_den;
+}
+
+/* Compute the per-plane SSIM normalization factors of a CTU from the depth-0
+ * source block: always for luma, and for U/V unless either the encoder or
+ * the source picture is I400. */
+void Analysis::calculateNormFactor(CUData& ctu, int qp)
+{
+    const uint32_t lumaSize = m_modeDepth[0].fencYuv.m_size;
+    normFactor(m_modeDepth[0].fencYuv.m_buf[0], lumaSize, ctu, qp, TEXT_LUMA);
+
+    const bool monochrome = m_csp == X265_CSP_I400 || m_frame->m_fencPic->m_picCsp == X265_CSP_I400;
+    if (monochrome)
+        return;
+
+    // Both chroma planes share one block size
+    const uint32_t chromaSize = m_modeDepth[0].fencYuv.m_csize;
+    normFactor(m_modeDepth[0].fencYuv.m_buf[1], chromaSize, ctu, qp, TEXT_CHROMA_U);
+    normFactor(m_modeDepth[0].fencYuv.m_buf[2], chromaSize, ctu, qp, TEXT_CHROMA_V);
+}
diff -r af10eaeb36cd -r 146036b4049c source/encoder/analysis.h
--- a/source/encoder/analysis.h	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/encoder/analysis.h	Wed Dec 28 19:12:02 2016 +0530
@@ -176,6 +176,8 @@
 
     int calculateQpforCuSize(const CUData& ctu, const CUGeom& cuGeom, double baseQP = -1);
 
+    void calculateNormFactor(CUData& ctu, int qp);
+    void normFactor(const pixel* src, uint32_t blockSize, CUData& ctu, int qp, TextType ttype);
     /* check whether current mode is the new best */
     inline void checkBestMode(Mode& mode, uint32_t depth)
     {
diff -r af10eaeb36cd -r 146036b4049c source/encoder/frameencoder.cpp
--- a/source/encoder/frameencoder.cpp	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/encoder/frameencoder.cpp	Wed Dec 28 19:12:02 2016 +0530
@@ -827,6 +827,7 @@
         m_frame->m_encData->m_frameStats.lumaDistortion   += m_rows[i].rowStats.lumaDistortion;
         m_frame->m_encData->m_frameStats.chromaDistortion += m_rows[i].rowStats.chromaDistortion;
         m_frame->m_encData->m_frameStats.psyEnergy        += m_rows[i].rowStats.psyEnergy;
+        m_frame->m_encData->m_frameStats.ssimEnergy       += m_rows[i].rowStats.ssimEnergy;
         m_frame->m_encData->m_frameStats.resEnergy        += m_rows[i].rowStats.resEnergy;
         for (uint32_t depth = 0; depth <= g_maxCUDepth; depth++)
         {
@@ -841,6 +842,7 @@
     m_frame->m_encData->m_frameStats.avgLumaDistortion   = (double)(m_frame->m_encData->m_frameStats.lumaDistortion) / m_frame->m_encData->m_frameStats.totalCtu;
     m_frame->m_encData->m_frameStats.avgChromaDistortion = (double)(m_frame->m_encData->m_frameStats.chromaDistortion) / m_frame->m_encData->m_frameStats.totalCtu;
     m_frame->m_encData->m_frameStats.avgPsyEnergy        = (double)(m_frame->m_encData->m_frameStats.psyEnergy) / m_frame->m_encData->m_frameStats.totalCtu;
+    m_frame->m_encData->m_frameStats.avgSsimEnergy       = (double)(m_frame->m_encData->m_frameStats.ssimEnergy) / m_frame->m_encData->m_frameStats.totalCtu;
     m_frame->m_encData->m_frameStats.avgResEnergy        = (double)(m_frame->m_encData->m_frameStats.resEnergy) / m_frame->m_encData->m_frameStats.totalCtu;
     m_frame->m_encData->m_frameStats.percentIntraNxN     = (double)(m_frame->m_encData->m_frameStats.cntIntraNxN * 100) / m_frame->m_encData->m_frameStats.totalCu;
     for (uint32_t depth = 0; depth <= g_maxCUDepth; depth++)
@@ -1419,6 +1421,7 @@
         curRow.rowStats.lumaDistortion   += best.lumaDistortion;
         curRow.rowStats.chromaDistortion += best.chromaDistortion;
         curRow.rowStats.psyEnergy        += best.psyEnergy;
+        curRow.rowStats.ssimEnergy       += best.ssimEnergy;
         curRow.rowStats.resEnergy        += best.resEnergy;
         curRow.rowStats.cntIntraNxN      += frameLog.cntIntraNxN;
         curRow.rowStats.totalCu          += frameLog.totalCu;
diff -r af10eaeb36cd -r 146036b4049c source/encoder/rdcost.h
--- a/source/encoder/rdcost.h	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/encoder/rdcost.h	Wed Dec 28 19:12:02 2016 +0530
@@ -41,9 +41,11 @@
     uint32_t  m_chromaDistWeight[2];
     uint32_t  m_psyRdBase;
     uint32_t  m_psyRd;
+    uint32_t  m_ssimRd;
     int       m_qp; /* QP used to configure lambda, may be higher than QP_MAX_SPEC but <= QP_MAX_MAX */
 
     void setPsyRdScale(double scale)                { m_psyRdBase = (uint32_t)floor(65536.0 * scale * 0.33); }
+    void setSsimRd(int ssimRd)                      { m_ssimRd = ssimRd; } /* stores the SSIM-RD switch in m_ssimRd */
 
     void setQP(const Slice& slice, int qp)
     {
@@ -129,6 +131,20 @@
         return distortion + ((m_lambda * m_psyRd * psycost) >> 24) + ((bits * m_lambda2) >> 8);
     }
 
+    /* RD cost with an SSIM-based term added to the plain distortion:
+     * dist + ((lambda * ssimCost) >> 14) + ((bits * lambda2) >> 8).
+     * Callers pass the result of Quant::ssimDistortion() as ssimCost. */
+    inline uint64_t calcSsimRdCost(uint64_t distortion, uint32_t bits, uint32_t ssimCost) const
+    {
+        // distortion is 64-bit here regardless of X265_DEPTH, so a single
+        // format string serves every build
+        X265_CHECK((bits <= (UINT64_MAX / m_lambda2)) && (ssimCost <= UINT64_MAX / m_lambda),
+                   "calcSsimRdCost wrap detected dist: " X265_LL ", bits: %u, lambda: " X265_LL ", lambda2: " X265_LL "\n",
+                   distortion, bits, m_lambda, m_lambda2);
+
+        return distortion + ((m_lambda * ssimCost) >> 14) + ((bits * m_lambda2) >> 8);
+    }
+
     inline uint64_t calcRdSADCost(uint32_t sadCost, uint32_t bits) const
     {
         X265_CHECK(bits <= (UINT64_MAX - 128) / m_lambda,
diff -r af10eaeb36cd -r 146036b4049c source/encoder/search.cpp
--- a/source/encoder/search.cpp	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/encoder/search.cpp	Wed Dec 28 19:12:02 2016 +0530
@@ -78,6 +78,7 @@
     m_numLayers = g_log2Size[param.maxCUSize] - 2;
 
     m_rdCost.setPsyRdScale(param.psyRd);
+    m_rdCost.setSsimRd(param.bSsimRd);
     m_me.init(param.internalCsp);
 
     bool ok = m_quant.init(param.psyRdoq, scalingList, m_entropyCoder);
@@ -417,6 +418,11 @@
             fullCost.energy = m_rdCost.psyCost(sizeIdx, fenc, mode.fencYuv->m_size, reconQt, reconQtStride);
             fullCost.rdcost = m_rdCost.calcPsyRdCost(fullCost.distortion, fullCost.bits, fullCost.energy);
         }
+        else if(m_rdCost.m_ssimRd)
+        {
+            fullCost.energy = m_quant.ssimDistortion(cu, fenc, stride, reconQt, reconQtStride, log2TrSize, TEXT_LUMA, absPartIdx);
+            fullCost.rdcost = m_rdCost.calcSsimRdCost(fullCost.distortion, fullCost.bits, fullCost.energy);
+        }
         else
             fullCost.rdcost = m_rdCost.calcRdCost(fullCost.distortion, fullCost.bits);
     }
@@ -460,6 +466,8 @@
 
             if (m_rdCost.m_psyRd)
                 splitCost.rdcost = m_rdCost.calcPsyRdCost(splitCost.distortion, splitCost.bits, splitCost.energy);
+            else if(m_rdCost.m_ssimRd)
+                splitCost.rdcost = m_rdCost.calcSsimRdCost(splitCost.distortion, splitCost.bits, splitCost.energy);
             else
                 splitCost.rdcost = m_rdCost.calcRdCost(splitCost.distortion, splitCost.bits);
         }
@@ -625,6 +633,11 @@
             tmpEnergy = m_rdCost.psyCost(sizeIdx, fenc, fencYuv->m_size, tmpRecon, tmpReconStride);
             tmpCost = m_rdCost.calcPsyRdCost(tmpDist, tmpBits, tmpEnergy);
         }
+        else if(m_rdCost.m_ssimRd)
+        {
+            tmpEnergy = m_quant.ssimDistortion(cu, fenc, stride, reconQt, reconQtStride, log2TrSize, TEXT_LUMA, absPartIdx);
+            tmpCost = m_rdCost.calcSsimRdCost(tmpDist, tmpBits, tmpEnergy);
+        }
         else
             tmpCost = m_rdCost.calcRdCost(tmpDist, tmpBits);
 
@@ -899,6 +912,8 @@
 
             if (m_rdCost.m_psyRd)
                 outCost.energy += m_rdCost.psyCost(sizeIdxC, fenc, stride, reconQt, reconQtStride);
+            else if(m_rdCost.m_ssimRd)
+                outCost.energy += m_quant.ssimDistortion(cu, fenc, stride, reconQt, reconQtStride, log2TrSizeC, ttype, absPartIdxC);
 
             primitives.cu[sizeIdxC].copy_pp(picReconC, picStride, reconQt, reconQtStride);
         }
@@ -1016,6 +1031,11 @@
                     tmpEnergy = m_rdCost.psyCost(sizeIdxC, fenc, stride, reconQt, reconQtStride);
                     tmpCost = m_rdCost.calcPsyRdCost(tmpDist, tmpBits, tmpEnergy);
                 }
+                else if(m_rdCost.m_ssimRd)
+                {
+                    tmpEnergy = m_quant.ssimDistortion(cu, fenc, stride, reconQt, reconQtStride, log2TrSizeC, ttype, absPartIdxC);
+                    tmpCost = m_rdCost.calcSsimRdCost(tmpDist, tmpBits, tmpEnergy);
+                }
                 else
                     tmpCost = m_rdCost.calcRdCost(tmpDist, tmpBits);
 
@@ -1229,11 +1249,12 @@
     m_entropyCoder.store(intraMode.contexts);
     intraMode.totalBits = m_entropyCoder.getNumberOfWrittenBits();
     intraMode.coeffBits = intraMode.totalBits - intraMode.mvBits - skipFlagBits;
+    const Yuv* fencYuv = intraMode.fencYuv;
     if (m_rdCost.m_psyRd)
-    {
-        const Yuv* fencYuv = intraMode.fencYuv;
         intraMode.psyEnergy = m_rdCost.psyCost(cuGeom.log2CUSize - 2, fencYuv->m_buf[0], fencYuv->m_size, intraMode.reconYuv.m_buf[0], intraMode.reconYuv.m_size);
-    }
+    else if(m_rdCost.m_ssimRd)
+        intraMode.ssimEnergy = m_quant.ssimDistortion(cu, fencYuv->m_buf[0], fencYuv->m_size, intraMode.reconYuv.m_buf[0], intraMode.reconYuv.m_size, cuGeom.log2CUSize, TEXT_LUMA, 0);
+
     intraMode.resEnergy = primitives.cu[cuGeom.log2CUSize - 2].sse_pp(intraMode.fencYuv->m_buf[0], intraMode.fencYuv->m_size, intraMode.predYuv.m_buf[0], intraMode.predYuv.m_size);
 
     updateModeCost(intraMode);
@@ -1448,12 +1469,13 @@
 
     intraMode.totalBits = m_entropyCoder.getNumberOfWrittenBits();
     intraMode.coeffBits = intraMode.totalBits - intraMode.mvBits - skipFlagBits;
+    const Yuv* fencYuv = intraMode.fencYuv;
     if (m_rdCost.m_psyRd)
-    {
-        const Yuv* fencYuv = intraMode.fencYuv;
         intraMode.psyEnergy = m_rdCost.psyCost(cuGeom.log2CUSize - 2, fencYuv->m_buf[0], fencYuv->m_size, reconYuv->m_buf[0], reconYuv->m_size);
-    }
-    intraMode.resEnergy = primitives.cu[cuGeom.log2CUSize - 2].sse_pp(intraMode.fencYuv->m_buf[0], intraMode.fencYuv->m_size, intraMode.predYuv.m_buf[0], intraMode.predYuv.m_size);
+    else if(m_rdCost.m_ssimRd)
+        intraMode.ssimEnergy = m_quant.ssimDistortion(cu, fencYuv->m_buf[0], fencYuv->m_size, reconYuv->m_buf[0], reconYuv->m_size, cuGeom.log2CUSize, TEXT_LUMA, 0);
+
+    intraMode.resEnergy = primitives.cu[cuGeom.log2CUSize - 2].sse_pp(fencYuv->m_buf[0], fencYuv->m_size, intraMode.predYuv.m_buf[0], intraMode.predYuv.m_size);
     m_entropyCoder.store(intraMode.contexts);
     updateModeCost(intraMode);
     checkDQP(intraMode, cuGeom);
@@ -1778,7 +1800,7 @@
             codeCoeffQTChroma(cu, initTuDepth, absPartIdxC, TEXT_CHROMA_U);
             codeCoeffQTChroma(cu, initTuDepth, absPartIdxC, TEXT_CHROMA_V);
             uint32_t bits = m_entropyCoder.getNumberOfWrittenBits();
-            uint64_t cost = m_rdCost.m_psyRd ? m_rdCost.calcPsyRdCost(outCost.distortion, bits, outCost.energy)
+            uint64_t cost = m_rdCost.m_psyRd ? m_rdCost.calcPsyRdCost(outCost.distortion, bits, outCost.energy) : m_rdCost.m_ssimRd ? m_rdCost.calcSsimRdCost(outCost.distortion, bits, outCost.energy)
                                              : m_rdCost.calcRdCost(outCost.distortion, bits);
 
             if (cost < bestCost)
@@ -2637,6 +2659,9 @@
     interMode.totalBits = interMode.mvBits + skipFlagBits;
     if (m_rdCost.m_psyRd)
         interMode.psyEnergy = m_rdCost.psyCost(part, fencYuv->m_buf[0], fencYuv->m_size, reconYuv->m_buf[0], reconYuv->m_size);
+    else if(m_rdCost.m_ssimRd)
+        interMode.ssimEnergy = m_quant.ssimDistortion(cu, fencYuv->m_buf[0], fencYuv->m_size, reconYuv->m_buf[0], reconYuv->m_size, cu.m_log2CUSize[0], TEXT_LUMA, 0);
+
     interMode.resEnergy = primitives.cu[part].sse_pp(fencYuv->m_buf[0], fencYuv->m_size, predYuv->m_buf[0], predYuv->m_size);
     updateModeCost(interMode);
     m_entropyCoder.store(interMode.contexts);
@@ -2707,13 +2732,17 @@
         m_entropyCoder.codeQtRootCbfZero();
         uint32_t cbf0Bits = m_entropyCoder.getNumberOfWrittenBits();
 
-        uint64_t cbf0Cost;
-        uint32_t cbf0Energy;
+        uint32_t cbf0Energy; uint64_t cbf0Cost;
         if (m_rdCost.m_psyRd)
         {
             cbf0Energy = m_rdCost.psyCost(log2CUSize - 2, fencYuv->m_buf[0], fencYuv->m_size, predYuv->m_buf[0], predYuv->m_size);
             cbf0Cost = m_rdCost.calcPsyRdCost(cbf0Dist, cbf0Bits, cbf0Energy);
         }
+        else if(m_rdCost.m_ssimRd)
+        {
+            cbf0Energy = m_quant.ssimDistortion(cu, fencYuv->m_buf[0], fencYuv->m_size, predYuv->m_buf[0], predYuv->m_size, log2CUSize, TEXT_LUMA, 0);
+            cbf0Cost = m_rdCost.calcSsimRdCost(cbf0Dist, cbf0Bits, cbf0Energy);
+        }
         else
             cbf0Cost = m_rdCost.calcRdCost(cbf0Dist, cbf0Bits);
 
@@ -2782,6 +2811,9 @@
     }
     if (m_rdCost.m_psyRd)
         interMode.psyEnergy = m_rdCost.psyCost(sizeIdx, fencYuv->m_buf[0], fencYuv->m_size, reconYuv->m_buf[0], reconYuv->m_size);
+    else if(m_rdCost.m_ssimRd)
+        interMode.ssimEnergy = m_quant.ssimDistortion(cu, fencYuv->m_buf[0], fencYuv->m_size, reconYuv->m_buf[0], reconYuv->m_size, cu.m_log2CUSize[0], TEXT_LUMA, 0);
+
     interMode.resEnergy = primitives.cu[sizeIdx].sse_pp(fencYuv->m_buf[0], fencYuv->m_size, predYuv->m_buf[0], predYuv->m_size);
     interMode.totalBits = bits;
     interMode.lumaDistortion = bestLumaDist;
@@ -2929,12 +2961,14 @@
     }
 }
 
-uint64_t Search::estimateNullCbfCost(sse_t dist, uint32_t psyEnergy, uint32_t tuDepth, TextType compId)
+uint64_t Search::estimateNullCbfCost(sse_t dist, uint32_t energy, uint32_t tuDepth, TextType compId)
 {
     uint32_t nullBits = m_entropyCoder.estimateCbfBits(0, compId, tuDepth);
 
     if (m_rdCost.m_psyRd)
-        return m_rdCost.calcPsyRdCost(dist, nullBits, psyEnergy);
+        return m_rdCost.calcPsyRdCost(dist, nullBits, energy);
+    else if(m_rdCost.m_ssimRd)
+        return m_rdCost.calcSsimRdCost(dist, nullBits, energy);
     else
         return m_rdCost.calcRdCost(dist, nullBits);
 }
@@ -2983,6 +3017,8 @@
 
     if (m_rdCost.m_psyRd)
         splitCost.rdcost = m_rdCost.calcPsyRdCost(splitCost.distortion, splitCost.bits, splitCost.energy);
+    else if(m_rdCost.m_ssimRd)
+        splitCost.rdcost = m_rdCost.calcSsimRdCost(splitCost.distortion, splitCost.bits, splitCost.energy);
     else
         splitCost.rdcost = m_rdCost.calcRdCost(splitCost.distortion, splitCost.bits);
         
@@ -3055,7 +3091,7 @@
     uint32_t numSig[MAX_NUM_COMPONENT][2 /*0 = top (or whole TU for non-4:2:2) sub-TU, 1 = bottom sub-TU*/] = { { 0, 0 }, {0, 0}, {0, 0} };
     uint32_t singleBits[MAX_NUM_COMPONENT][2 /*0 = top (or whole TU for non-4:2:2) sub-TU, 1 = bottom sub-TU*/] = { { 0, 0 }, { 0, 0 }, { 0, 0 } };
     sse_t singleDist[MAX_NUM_COMPONENT][2 /*0 = top (or whole TU for non-4:2:2) sub-TU, 1 = bottom sub-TU*/] = { { 0, 0 }, { 0, 0 }, { 0, 0 } };
-    uint32_t singlePsyEnergy[MAX_NUM_COMPONENT][2 /*0 = top (or whole TU for non-4:2:2) sub-TU, 1 = bottom sub-TU*/] = { { 0, 0 }, { 0, 0 }, { 0, 0 } };
+    uint32_t singleEnergy[MAX_NUM_COMPONENT][2 /*0 = top (or whole TU for non-4:2:2) sub-TU, 1 = bottom sub-TU*/] = { { 0, 0 }, { 0, 0 }, { 0, 0 } };
     uint32_t bestTransformMode[MAX_NUM_COMPONENT][2 /*0 = top (or whole TU for non-4:2:2) sub-TU, 1 = bottom sub-TU*/] = { { 0, 0 }, { 0, 0 }, { 0, 0 } };
     uint64_t minCost[MAX_NUM_COMPONENT][2 /*0 = top (or whole TU for non-4:2:2) sub-TU, 1 = bottom sub-TU*/] = { { MAX_INT64, MAX_INT64 }, {MAX_INT64, MAX_INT64}, {MAX_INT64, MAX_INT64} };
 
@@ -3104,9 +3140,11 @@
 
         //Assuming zero residual 
         sse_t zeroDistY = primitives.cu[partSize].sse_pp(fenc, fencYuv->m_size, mode.predYuv.getLumaAddr(absPartIdx), mode.predYuv.m_size);
-        uint32_t zeroPsyEnergyY = 0;
+        uint32_t zeroEnergyY = 0;
         if (m_rdCost.m_psyRd)
-            zeroPsyEnergyY = m_rdCost.psyCost(partSize, fenc, fencYuv->m_size, mode.predYuv.getLumaAddr(absPartIdx), mode.predYuv.m_size);
+            zeroEnergyY = m_rdCost.psyCost(partSize, fenc, fencYuv->m_size, mode.predYuv.getLumaAddr(absPartIdx), mode.predYuv.m_size);
+        else if(m_rdCost.m_ssimRd)
+            zeroEnergyY = m_quant.ssimDistortion(cu, fenc, fencYuv->m_size, mode.predYuv.getLumaAddr(absPartIdx), mode.predYuv.m_size, log2TrSize, TEXT_LUMA, absPartIdx);
 
         int16_t* curResiY = m_rqt[qtLayer].resiQtYuv.getLumaAddr(absPartIdx);
         uint32_t strideResiY = m_rqt[qtLayer].resiQtYuv.m_size;
@@ -3123,11 +3161,16 @@
 
             const sse_t nonZeroDistY = primitives.cu[partSize].sse_pp(fenc, fencYuv->m_size, curReconY, strideReconY);
             uint32_t nzCbfBitsY = m_entropyCoder.estimateCbfBits(cbfFlag[TEXT_LUMA][0], TEXT_LUMA, tuDepth);
-            uint32_t nonZeroPsyEnergyY = 0; uint64_t singleCostY = 0;
+            uint32_t nonZeroEnergyY = 0; uint64_t singleCostY = 0;
             if (m_rdCost.m_psyRd)
             {
-                nonZeroPsyEnergyY = m_rdCost.psyCost(partSize, fenc, fencYuv->m_size, curReconY, strideReconY);
-                singleCostY = m_rdCost.calcPsyRdCost(nonZeroDistY, nzCbfBitsY + singleBits[TEXT_LUMA][0], nonZeroPsyEnergyY);
+                nonZeroEnergyY = m_rdCost.psyCost(partSize, fenc, fencYuv->m_size, curReconY, strideReconY);
+                singleCostY = m_rdCost.calcPsyRdCost(nonZeroDistY, nzCbfBitsY + singleBits[TEXT_LUMA][0], nonZeroEnergyY);
+            }
+            else if(m_rdCost.m_ssimRd)
+            {
+                nonZeroEnergyY = m_quant.ssimDistortion(cu, fenc, fencYuv->m_size, curReconY, strideReconY, log2TrSize, TEXT_LUMA, absPartIdx);
+                singleCostY = m_rdCost.calcSsimRdCost(nonZeroDistY, nzCbfBitsY + singleBits[TEXT_LUMA][0], nonZeroEnergyY);
             }
             else
                 singleCostY = m_rdCost.calcRdCost(nonZeroDistY, nzCbfBitsY + singleBits[TEXT_LUMA][0]);
@@ -3135,14 +3178,14 @@
             if (cu.m_tqBypass[0])
             {
                 singleDist[TEXT_LUMA][0] = nonZeroDistY;
-                singlePsyEnergy[TEXT_LUMA][0] = nonZeroPsyEnergyY;
+                singleEnergy[TEXT_LUMA][0] = nonZeroEnergyY;
             }
             else
             {
                 // zero-cost calculation for luma. This is an approximation
                 // Initial cost calculation was also an approximation. First resetting the bit counter and then encoding zero cbf.
                 // Now encoding the zero cbf without writing into bitstream, keeping m_fracBits unchanged. The same is valid for chroma.
-                uint64_t nullCostY = estimateNullCbfCost(zeroDistY, zeroPsyEnergyY, tuDepth, TEXT_LUMA);
+                uint64_t nullCostY = estimateNullCbfCost(zeroDistY, zeroEnergyY, tuDepth, TEXT_LUMA);
 
                 if (nullCostY < singleCostY)
                 {
@@ -3156,25 +3199,25 @@
                     if (checkTransformSkipY)
                         minCost[TEXT_LUMA][0] = nullCostY;
                     singleDist[TEXT_LUMA][0] = zeroDistY;
-                    singlePsyEnergy[TEXT_LUMA][0] = zeroPsyEnergyY;
+                    singleEnergy[TEXT_LUMA][0] = zeroEnergyY;
                 }
                 else
                 {
                     if (checkTransformSkipY)
                         minCost[TEXT_LUMA][0] = singleCostY;
                     singleDist[TEXT_LUMA][0] = nonZeroDistY;
-                    singlePsyEnergy[TEXT_LUMA][0] = nonZeroPsyEnergyY;
+                    singleEnergy[TEXT_LUMA][0] = nonZeroEnergyY;
                 }
             }
         }
         else
         {
             if (checkTransformSkipY)
-                minCost[TEXT_LUMA][0] = estimateNullCbfCost(zeroDistY, zeroPsyEnergyY, tuDepth, TEXT_LUMA);
+                minCost[TEXT_LUMA][0] = estimateNullCbfCost(zeroDistY, zeroEnergyY, tuDepth, TEXT_LUMA);
             primitives.cu[partSize].blockfill_s(curResiY, strideResiY, 0);
             singleDist[TEXT_LUMA][0] = zeroDistY;
             singleBits[TEXT_LUMA][0] = 0;
-            singlePsyEnergy[TEXT_LUMA][0] = zeroPsyEnergyY;
+            singleEnergy[TEXT_LUMA][0] = zeroEnergyY;
         }
 
         cu.setCbfSubParts(cbfFlag[TEXT_LUMA][0] << tuDepth, TEXT_LUMA, absPartIdx, depth);
@@ -3186,7 +3229,7 @@
             for (uint32_t chromaId = TEXT_CHROMA_U; chromaId <= TEXT_CHROMA_V; chromaId++)
             {
                 sse_t zeroDistC = 0;
-                uint32_t zeroPsyEnergyC = 0;
+                uint32_t zeroEnergyC = 0;
                 coeff_t* coeffCurC = m_rqt[qtLayer].coeffRQT[chromaId] + coeffOffsetC;
                 TURecurse tuIterator(splitIntoSubTUs ? VERTICAL_SPLIT : DONT_SPLIT, absPartIdxStep, absPartIdx);
 
@@ -3214,9 +3257,11 @@
                     int16_t* curResiC = m_rqt[qtLayer].resiQtYuv.getChromaAddr(chromaId, absPartIdxC);
                     zeroDistC = m_rdCost.scaleChromaDist(chromaId, primitives.cu[log2TrSizeC - 2].sse_pp(fenc, fencYuv->m_csize, mode.predYuv.getChromaAddr(chromaId, absPartIdxC), mode.predYuv.m_csize));
 
+                    // Assuming zero residual 
                     if (m_rdCost.m_psyRd)
-                    //Assuming zero residual 
-                        zeroPsyEnergyC = m_rdCost.psyCost(partSizeC, fenc, fencYuv->m_csize, mode.predYuv.getChromaAddr(chromaId, absPartIdxC), mode.predYuv.m_csize);
+                        zeroEnergyC = m_rdCost.psyCost(partSizeC, fenc, fencYuv->m_csize, mode.predYuv.getChromaAddr(chromaId, absPartIdxC), mode.predYuv.m_csize);
+                    else if(m_rdCost.m_ssimRd)
+                        zeroEnergyC = m_quant.ssimDistortion(cu, fenc, fencYuv->m_csize, mode.predYuv.getChromaAddr(chromaId, absPartIdxC), mode.predYuv.m_csize, log2TrSizeC, (TextType)chromaId, absPartIdxC);
 
                     if (cbfFlag[chromaId][tuIterator.section])
                     {
@@ -3230,11 +3275,16 @@
                         primitives.cu[partSizeC].add_ps(curReconC, strideReconC, mode.predYuv.getChromaAddr(chromaId, absPartIdxC), curResiC, mode.predYuv.m_csize, strideResiC);
                         sse_t nonZeroDistC = m_rdCost.scaleChromaDist(chromaId, primitives.cu[partSizeC].sse_pp(fenc, fencYuv->m_csize, curReconC, strideReconC));
                         uint32_t nzCbfBitsC = m_entropyCoder.estimateCbfBits(cbfFlag[chromaId][tuIterator.section], (TextType)chromaId, tuDepth);
-                        uint32_t nonZeroPsyEnergyC = 0; uint64_t singleCostC = 0;
+                        uint32_t nonZeroEnergyC = 0; uint64_t singleCostC = 0;
                         if (m_rdCost.m_psyRd)
                         {
-                            nonZeroPsyEnergyC = m_rdCost.psyCost(partSizeC, fenc, fencYuv->m_csize, curReconC, strideReconC);
-                            singleCostC = m_rdCost.calcPsyRdCost(nonZeroDistC, nzCbfBitsC + singleBits[chromaId][tuIterator.section], nonZeroPsyEnergyC);
+                            nonZeroEnergyC = m_rdCost.psyCost(partSizeC, fenc, fencYuv->m_csize, curReconC, strideReconC);
+                            singleCostC = m_rdCost.calcPsyRdCost(nonZeroDistC, nzCbfBitsC + singleBits[chromaId][tuIterator.section], nonZeroEnergyC);
+                        }
+                        else if(m_rdCost.m_ssimRd)
+                        {
+                            nonZeroEnergyC = m_quant.ssimDistortion(cu, fenc, fencYuv->m_csize, curReconC, strideReconC, log2TrSizeC, (TextType)chromaId, absPartIdxC);
+                            singleCostC = m_rdCost.calcSsimRdCost(nonZeroDistC, nzCbfBitsC + singleBits[chromaId][tuIterator.section], nonZeroEnergyC);
                         }
                         else
                             singleCostC = m_rdCost.calcRdCost(nonZeroDistC, nzCbfBitsC + singleBits[chromaId][tuIterator.section]);
@@ -3242,12 +3292,12 @@
                         if (cu.m_tqBypass[0])
                         {
                             singleDist[chromaId][tuIterator.section] = nonZeroDistC;
-                            singlePsyEnergy[chromaId][tuIterator.section] = nonZeroPsyEnergyC;
+                            singleEnergy[chromaId][tuIterator.section] = nonZeroEnergyC;
                         }
                         else
                         {
                             //zero-cost calculation for chroma. This is an approximation
-                            uint64_t nullCostC = estimateNullCbfCost(zeroDistC, zeroPsyEnergyC, tuDepth, (TextType)chromaId);
+                            uint64_t nullCostC = estimateNullCbfCost(zeroDistC, zeroEnergyC, tuDepth, (TextType)chromaId);
 
                             if (nullCostC < singleCostC)
                             {
@@ -3261,25 +3311,25 @@
                                 if (checkTransformSkipC)
                                     minCost[chromaId][tuIterator.section] = nullCostC;
                                 singleDist[chromaId][tuIterator.section] = zeroDistC;
-                                singlePsyEnergy[chromaId][tuIterator.section] = zeroPsyEnergyC;
+                                singleEnergy[chromaId][tuIterator.section] = zeroEnergyC;
                             }
                             else
                             {
                                 if (checkTransformSkipC)
                                     minCost[chromaId][tuIterator.section] = singleCostC;
                                 singleDist[chromaId][tuIterator.section] = nonZeroDistC;
-                                singlePsyEnergy[chromaId][tuIterator.section] = nonZeroPsyEnergyC;
+                                singleEnergy[chromaId][tuIterator.section] = nonZeroEnergyC;
                             }
                         }
                     }
                     else
                     {
                         if (checkTransformSkipC)
-                            minCost[chromaId][tuIterator.section] = estimateNullCbfCost(zeroDistC, zeroPsyEnergyC, tuDepthC, (TextType)chromaId);
+                            minCost[chromaId][tuIterator.section] = estimateNullCbfCost(zeroDistC, zeroEnergyC, tuDepthC, (TextType)chromaId);
                         primitives.cu[partSizeC].blockfill_s(curResiC, strideResiC, 0);
                         singleBits[chromaId][tuIterator.section] = 0;
                         singleDist[chromaId][tuIterator.section] = zeroDistC;
-                        singlePsyEnergy[chromaId][tuIterator.section] = zeroPsyEnergyC;
+                        singleEnergy[chromaId][tuIterator.section] = zeroEnergyC;
                     }
 
                     cu.setCbfPartRange(cbfFlag[chromaId][tuIterator.section] << tuDepth, (TextType)chromaId, absPartIdxC, tuIterator.absPartIdxStep);
@@ -3304,7 +3354,7 @@
         if (checkTransformSkipY)
         {
             sse_t nonZeroDistY = 0;
-            uint32_t nonZeroPsyEnergyY = 0;
+            uint32_t nonZeroEnergyY = 0;
             uint64_t singleCostY = MAX_INT64;
 
             m_entropyCoder.load(m_rqt[depth].rqtRoot);
@@ -3332,8 +3382,13 @@
 
                 if (m_rdCost.m_psyRd)
                 {
-                    nonZeroPsyEnergyY = m_rdCost.psyCost(partSize, fenc, fencYuv->m_size, m_tsRecon, trSize);
-                    singleCostY = m_rdCost.calcPsyRdCost(nonZeroDistY, skipSingleBitsY, nonZeroPsyEnergyY);
+                    nonZeroEnergyY = m_rdCost.psyCost(partSize, fenc, fencYuv->m_size, m_tsRecon, trSize);
+                    singleCostY = m_rdCost.calcPsyRdCost(nonZeroDistY, skipSingleBitsY, nonZeroEnergyY);
+                }
+                else if(m_rdCost.m_ssimRd)
+                {
+                    nonZeroEnergyY = m_quant.ssimDistortion(cu, fenc, fencYuv->m_size, m_tsRecon, trSize, log2TrSize, TEXT_LUMA, absPartIdx);
+                    singleCostY = m_rdCost.calcSsimRdCost(nonZeroDistY, skipSingleBitsY, nonZeroEnergyY);
                 }
                 else
                     singleCostY = m_rdCost.calcRdCost(nonZeroDistY, skipSingleBitsY);
@@ -3344,7 +3399,7 @@
             else
             {
                 singleDist[TEXT_LUMA][0] = nonZeroDistY;
-                singlePsyEnergy[TEXT_LUMA][0] = nonZeroPsyEnergyY;
+                singleEnergy[TEXT_LUMA][0] = nonZeroEnergyY;
                 cbfFlag[TEXT_LUMA][0] = !!numSigTSkipY;
                 bestTransformMode[TEXT_LUMA][0] = 1;
                 if (m_param->limitTU)
@@ -3360,7 +3415,7 @@
         if (codeChroma && checkTransformSkipC)
         {
             sse_t nonZeroDistC = 0;
-            uint32_t nonZeroPsyEnergyC = 0;
+            uint32_t nonZeroEnergyC = 0;
             uint64_t singleCostC = MAX_INT64;
             uint32_t strideResiC = m_rqt[qtLayer].resiQtYuv.m_csize;
             uint32_t coeffOffsetC = coeffOffsetY >> (m_hChromaShift + m_vChromaShift);
@@ -3403,9 +3458,13 @@
                         nonZeroDistC = m_rdCost.scaleChromaDist(chromaId, primitives.cu[partSizeC].sse_pp(fenc, fencYuv->m_csize, m_tsRecon, trSizeC));
                         if (m_rdCost.m_psyRd)
                         {
-
-                            nonZeroPsyEnergyC = m_rdCost.psyCost(partSizeC, fenc, fencYuv->m_csize, m_tsRecon, trSizeC);
-                            singleCostC = m_rdCost.calcPsyRdCost(nonZeroDistC, singleBits[chromaId][tuIterator.section], nonZeroPsyEnergyC);
+                            nonZeroEnergyC = m_rdCost.psyCost(partSizeC, fenc, fencYuv->m_csize, m_tsRecon, trSizeC);
+                            singleCostC = m_rdCost.calcPsyRdCost(nonZeroDistC, singleBits[chromaId][tuIterator.section], nonZeroEnergyC);
+                        }
+                        else if(m_rdCost.m_ssimRd)
+                        {
+                            nonZeroEnergyC = m_quant.ssimDistortion(cu, fenc, mode.fencYuv->m_csize, m_tsRecon, trSizeC, log2TrSizeC, (TextType)chromaId, absPartIdxC);
+                            singleCostC = m_rdCost.calcSsimRdCost(nonZeroDistC, singleBits[chromaId][tuIterator.section], nonZeroEnergyC);
                         }
                         else
                             singleCostC = m_rdCost.calcRdCost(nonZeroDistC, singleBits[chromaId][tuIterator.section]);
@@ -3416,7 +3475,7 @@
                     else
                     {
                         singleDist[chromaId][tuIterator.section] = nonZeroDistC;
-                        singlePsyEnergy[chromaId][tuIterator.section] = nonZeroPsyEnergyC;
+                        singleEnergy[chromaId][tuIterator.section] = nonZeroEnergyC;
                         cbfFlag[chromaId][tuIterator.section] = !!numSigTSkipC;
                         bestTransformMode[chromaId][tuIterator.section] = 1;
                         uint32_t numCoeffC = 1 << (log2TrSizeC << 1);
@@ -3475,7 +3534,7 @@
         fullCost.bits = bSplitPresentFlag ? cbfBits + coeffBits : coeffBits;
 
         fullCost.distortion += singleDist[TEXT_LUMA][0];
-        fullCost.energy += singlePsyEnergy[TEXT_LUMA][0];// need to check we need to add chroma also
+        fullCost.energy += singleEnergy[TEXT_LUMA][0];// need to check we need to add chroma also
         for (uint32_t subTUIndex = 0; subTUIndex < 2; subTUIndex++)
         {
             fullCost.distortion += singleDist[TEXT_CHROMA_U][subTUIndex];
@@ -3484,6 +3543,8 @@
 
         if (m_rdCost.m_psyRd)
             fullCost.rdcost = m_rdCost.calcPsyRdCost(fullCost.distortion, fullCost.bits, fullCost.energy);
+        else if(m_rdCost.m_ssimRd)
+            fullCost.rdcost = m_rdCost.calcSsimRdCost(fullCost.distortion, fullCost.bits, fullCost.energy);
         else
             fullCost.rdcost = m_rdCost.calcRdCost(fullCost.distortion, fullCost.bits);
 
diff -r af10eaeb36cd -r 146036b4049c source/encoder/search.h
--- a/source/encoder/search.h	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/encoder/search.h	Wed Dec 28 19:12:02 2016 +0530
@@ -118,6 +118,7 @@
     uint64_t    sa8dCost;   // sum of partition sa8d distortion costs   (sa8d(fenc, pred) + lambda * bits)
     uint32_t    sa8dBits;   // signal bits used in sa8dCost calculation
     uint32_t    psyEnergy;  // sum of partition psycho-visual energy difference
+    uint32_t    ssimEnergy; // sum of partition SSIM energy, consumed by calcSsimRdCost() in updateModeCost()
     sse_t   resEnergy;  // sum of partition residual energy after motion prediction
     sse_t   lumaDistortion;
     sse_t   chromaDistortion;
@@ -132,6 +133,7 @@
         sa8dCost = 0;
         sa8dBits = 0;
         psyEnergy = 0;
+        ssimEnergy = 0;
         resEnergy = 0;
         lumaDistortion = 0;
         chromaDistortion = 0;
@@ -147,6 +149,7 @@
         sa8dCost += subMode.sa8dCost;
         sa8dBits += subMode.sa8dBits;
         psyEnergy += subMode.psyEnergy;
+        ssimEnergy += subMode.ssimEnergy;
         resEnergy += subMode.resEnergy;
         lumaDistortion += subMode.lumaDistortion;
         chromaDistortion += subMode.chromaDistortion;
@@ -390,7 +393,7 @@
         Entropy rqtStore[NUM_SUBPART];
     } m_cacheTU;
 
-    uint64_t estimateNullCbfCost(sse_t dist, uint32_t psyEnergy, uint32_t tuDepth, TextType compId);
+    uint64_t estimateNullCbfCost(sse_t dist, uint32_t energy, uint32_t tuDepth, TextType compId);
     bool     splitTU(Mode& mode, const CUGeom& cuGeom, uint32_t absPartIdx, uint32_t tuDepth, ShortYuv& resiYuv, Cost& splitCost, const uint32_t depthRange[2], int32_t splitMore);
     void     estimateResidualQT(Mode& mode, const CUGeom& cuGeom, uint32_t absPartIdx, uint32_t depth, ShortYuv& resiYuv, Cost& costs, const uint32_t depthRange[2], int32_t splitMore = -1);
 
@@ -430,7 +433,9 @@
     // get most probable luma modes for CU part, and bit cost of all non mpm modes
     uint32_t getIntraRemModeBits(CUData & cu, uint32_t absPartIdx, uint32_t mpmModes[3], uint64_t& mpms) const;
 
-    void updateModeCost(Mode& m) const { m.rdCost = m_rdCost.m_psyRd ? m_rdCost.calcPsyRdCost(m.distortion, m.totalBits, m.psyEnergy) : m_rdCost.calcRdCost(m.distortion, m.totalBits); }
+    void updateModeCost(Mode& m) const { m.rdCost = m_rdCost.m_psyRd ? m_rdCost.calcPsyRdCost(m.distortion, m.totalBits, m.psyEnergy)
+                                                : (m_rdCost.m_ssimRd ? m_rdCost.calcSsimRdCost(m.distortion, m.totalBits, m.ssimEnergy) 
+                                                : m_rdCost.calcRdCost(m.distortion, m.totalBits)); }
 };
 }
 
diff -r af10eaeb36cd -r 146036b4049c source/x265.h
--- a/source/x265.h	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/x265.h	Wed Dec 28 19:12:02 2016 +0530
@@ -1058,6 +1058,11 @@
      * the encoder must perform. Default X265_ANALYSIS_OFF */
     int       analysisMode;
 
+    /* SSIM-based RDO using a residual divisive normalization scheme. Used for mode
+     * selection during analysis of CTUs; can achieve significant gains in the
+     * objective quality metrics SSIM and PSNR */
+    int       bSsimRd;
+
     /* Filename for analysisMode save/load. Default name is "x265_analysis.dat" */
     const char* analysisFileName;
 
diff -r af10eaeb36cd -r 146036b4049c source/x265cli.h
--- a/source/x265cli.h	Wed Dec 28 10:17:08 2016 +0530
+++ b/source/x265cli.h	Wed Dec 28 19:12:02 2016 +0530
@@ -256,6 +256,8 @@
     { "analyze-src-pics", no_argument, NULL, 0 },
     { "no-analyze-src-pics", no_argument, NULL, 0 },
     { "slices",         required_argument, NULL, 0 },
+    { "ssim-rd",      no_argument, NULL, 0 },
+    { "no-ssim-rd",   no_argument, NULL, 0 },
     { 0, 0, 0, 0 },
     { 0, 0, 0, 0 },
     { 0, 0, 0, 0 },
@@ -340,6 +342,7 @@
     H0("   --[no-]psy-rd <0..5.0>        Strength of psycho-visual rate distortion optimization, 0 to disable. Default %.1f\n", param->psyRd);
     H0("   --[no-]rdoq-level <0|1|2>     Level of RDO in quantization 0:none, 1:levels, 2:levels & coding groups. Default %d\n", param->rdoqLevel);
     H0("   --[no-]psy-rdoq <0..50.0>     Strength of psycho-visual optimization in RDO quantization, 0 to disable. Default %.1f\n", param->psyRdoq);
+    H0("   --[no-]ssim-rd                Enable ssim rate distortion optimization. Default %s\n", OPT(param->bSsimRd));
     H0("   --[no-]rd-refine              Enable QP based RD refinement for rd levels 5 and 6. Default %s\n", OPT(param->bEnableRdRefine));
     H0("   --[no-]early-skip             Enable early SKIP detection. Default %s\n", OPT(param->bEnableEarlySkip));
     H0("   --[no-]rskip                  Enable early exit from recursion. Default %s\n", OPT(param->bEnableRecursionSkip));



More information about the x265-devel mailing list